Package aloha :: Module aloha_lib
[hide private]
[frames] | [no frames]

Source Code for Module aloha.aloha_lib

   1  ################################################################################ 
   2  # 
   3  # Copyright (c) 2010 The MadGraph5_aMC@NLO Development team and Contributors 
   4  # 
   5  # This file is a part of the MadGraph5_aMC@NLO project, an application which  
   6  # automatically generates Feynman diagrams and matrix elements for arbitrary 
   7  # high-energy processes in the Standard Model and beyond. 
   8  # 
   9  # It is subject to the MadGraph5_aMC@NLO license which should accompany this  
  10  # distribution. 
  11  # 
  12  # For more information, visit madgraph.phys.ucl.ac.be and amcatnlo.web.cern.ch 
  13  # 
  14  ################################################################################ 
  15  ##   Diagram of Class 
  16  ## 
  17  ##    Variable (vartype:0)<--- ScalarVariable  
  18  ##                          | 
  19  ##                          +- LorentzObject  
  20  ##                                 
  21  ## 
  22  ##    list <--- AddVariable (vartype :1)    
  23  ##            
  24  ##    array <--- MultVariable  <--- MultLorentz (vartype:2)  
  25  ##            
  26  ##    list <--- LorentzObjectRepresentation (vartype :4) <-- ConstantObject 
  27  ##                                                               (vartype:5) 
  28  ## 
  29  ##    FracVariable (vartype:3) 
  30  ## 
  31  ##    MultContainer (vartype:6) 
  32  ## 
  33  ################################################################################ 
  34  ## 
   35  ##   Variable is in fact a factory which adds a reference to the variable name 
   36  ##   into the KERNEL (of class Computation), instantiates a real variable object 
   37  ##   (of class C_Variable or R_Variable for complex/real) and returns a MultVariable 
   38  ##   with a single element. 
  39  ## 
  40  ##   Lorentz Object works in the same way. 
  41  ## 
  42  ################################################################################ 
  43   
  44   
  45  from __future__ import division 
  46  from array import array 
  47  import collections 
  48  from fractions import Fraction 
  49  import numbers 
  50  import re 
  51  import aloha # define mode of writting 
class defaultdict(collections.defaultdict):
    """A ``collections.defaultdict`` whose instances double as factories.

    Calling an instance ignores every argument and hands back a brand new
    ``defaultdict(int)``.  This is what lets the nested-counter idiom
    ``defaultdict(defaultdict(int))`` work in this module.
    """

    def __call__(self, *args):
        # Arguments are deliberately ignored: every call yields a fresh,
        # independent integer counter.
        fresh_counter = defaultdict(int)
        return fresh_counter
class Computation(dict):
    """ a class to encapsulate all computation. Limit side effect

    The dict part maps a variable name to its integer id; ``objs`` holds the
    registered objects, indexed by that id.
    """

    # Functions that are not reported in ``unknow_fct``.
    known_fct = ['/', 'log', 'pow', 'sin', 'cos', 'asin', 'acos', 'tan', 'cot',
                 'acot', 'theta_function', 'exp']

    def __init__(self):
        self.objs = []           # id -> registered object
        self.use_tag = set()     # tags (TMPx, FCTx, ...) marked as used
        self.id = -1             # last id handed out by add()
        self.reduced_expr = {}   # str(expression) -> [Variable(tag), tag]
        self.fct_expr = {}       # 'FCTx' -> (fct_tag, argument list)
        self.reduced_expr2 = {}  # tag -> reduced expression or (fct_tag, args)
        self.inverted_fct = {}   # str(fct_tag)+str(arguments) -> 'FCT(x)'
        self.has_pi = False # logical to check if pi is used in at least one fct
        self.unknow_fct = []     # (fct_tag, nb_args) of non-standard functions
        dict.__init__(self)

    def clean(self):
        """Reset the object to its freshly created (empty) state."""
        self.__init__()
        self.clear()

    def add(self, name, obj):
        """Register obj under name and return the new integer id."""
        self.id += 1
        self.objs.append(obj)
        self[name] = self.id
        return self.id

    def get(self, name):
        """Return the object registered under name (KeyError if unknown)."""
        return self.objs[self[name]]

    def add_tag(self, tag):
        """Mark every element of the iterable tag as used."""
        self.use_tag.update(tag)

    def get_ids(self, variables):
        """return the list of identification number associate to the
        given variables names. If a variable didn't exists, create it (in complex).
        """
        out = []
        for var in variables:
            try:
                var_id = self[var]
            except KeyError:
                assert var not in ['M', 'W']
                var_id = Variable(var).get_id()
            out.append(var_id)
        return out

    def add_expression_contraction(self, expression):
        """Replace expression by a TMPn variable and return that variable.

        Identical expressions (same str()) reuse the same variable. Returns 0
        when the expression (or its simplification) is zero. The simplified,
        factorized expression is stored in reduced_expr2 under the tag.
        """
        str_expr = str(expression)
        if str_expr in self.reduced_expr:
            out, tag = self.reduced_expr[str_expr]
            self.add_tag((tag,))
            return out
        if expression == 0:
            return 0
        new_2 = expression.simplify()
        if new_2 == 0:
            return 0
        # Add a new variable
        tag = 'TMP%s' % len(self.reduced_expr)
        new = Variable(tag)
        self.reduced_expr[str_expr] = [new, tag]
        # simplify() may return a plain number, which has nothing to factorize
        if hasattr(new_2, 'factorize'):
            new_2 = new_2.factorize()
        self.reduced_expr2[tag] = new_2
        self.add_tag((tag,))
        return new

    def add_function_expression(self, fct_tag, *args):
        """Register the call fct_tag(*args) and return its 'FCT(n)' tag.

        Functions that are neither 'cmath.*' nor in known_fct are recorded in
        unknow_fct with their number of arguments. Expression arguments must
        be scalar: they are expanded, simplified and factorized. An identical
        call (same fct_tag and same argument string) reuses the same tag.

        Raises aloha.ALOHAERROR if an expression argument is not scalar.
        """
        if not (fct_tag.startswith('cmath.') or fct_tag in self.known_fct or
                (fct_tag, len(args)) in self.unknow_fct):
            self.unknow_fct.append((fct_tag, len(args)))

        argument = []
        for expression in args:
            if isinstance(expression, (MultLorentz, AddVariable, LorentzObject)):
                try:
                    expr = expression.expand().get_rep([0])
                except KeyError as error:
                    if error.args != ((0,),):
                        raise
                    else:
                        raise aloha.ALOHAERROR('''Error in input format.
Argument of function (or denominator) should be scalar.
We found %s''' % expression)
                # Factorize the *simplified* expression; simplify() may also
                # reduce it to a plain number (e.g. 0) that has no factorize.
                new = expr.simplify() if hasattr(expr, 'simplify') else expr
                if hasattr(new, 'factorize'):
                    new = new.factorize()
                argument.append(new)
            else:
                argument.append(expression)

        # Functions nested inside the arguments are used as well.
        for arg in argument:
            for v in re.findall(r'\bFCT(\d*)\b', str(arg)):
                self.add_tag(('FCT%s' % v,))

        key = str(fct_tag) + str(argument)
        if key in self.inverted_fct:
            tag = self.inverted_fct[key]
            v = tag.split('(')[1][:-1]
            self.add_tag(('FCT%s' % v,))
            return tag

        fct_id = len(self.fct_expr)
        tag = 'FCT%s' % fct_id
        self.inverted_fct[key] = 'FCT(%s)' % fct_id
        self.fct_expr[tag] = (fct_tag, argument)
        self.reduced_expr2[tag] = (fct_tag, argument)
        self.add_tag((tag,))
        return 'FCT(%s)' % fct_id

KERNEL = Computation()
#===============================================================================
# AddVariable
#===============================================================================
class AddVariable(list):
    """ A list of Variable/ConstantObject/... This object represent the operation
    between those object.

    The elements are summed and the whole sum is multiplied by ``prefactor``.
    Elements are usually MultVariable objects, but plain numbers and other
    expression objects can appear as well.
    """

    #variable to fastenize class recognition
    vartype = 1

    def __init__(self, old_data=(), prefactor=1):
        """ initialization of the object with default value """
        self.prefactor = prefactor
        list.__init__(self, old_data)

    def simplify(self):
        """ apply rule of simplification (in place)

        Plain numbers are summed into one constant. Terms with the same sorted
        variable content are merged, and vanishing terms are dropped. When a
        |prefactor| is shared by more than one term it is pulled in front.
        Returns a number, a single simplified term, or self.
        """
        # deal with one length object (a lone number goes the general way)
        if len(self) == 1 and hasattr(self[0], 'vartype'):
            return self.prefactor * self[0].simplify()
        constant = 0
        items = {}
        pos = -1
        for term in self[:]:
            pos += 1 # current position in the real self
            if not hasattr(term, 'vartype'):
                if isinstance(term, dict):
                    # allow term of type{(0,):x}
                    assert list(term.values()) == [0]
                    term = term[(0,)]
                constant += term
                del self[pos]
                pos -= 1
                continue
            tag = tuple(term.sort())
            if tag in items:
                orig_prefac = items[tag].prefactor # to assume to zero 0.33333 -0.3333
                items[tag].prefactor += term.prefactor
                if items[tag].prefactor and \
                   abs(items[tag].prefactor) / (abs(orig_prefac) + abs(term.prefactor)) < 1e-8:
                    items[tag].prefactor = 0
                del self[pos]
                pos -= 1
            else:
                items[tag] = term.__class__(term, term.prefactor)
                self[pos] = items[tag]

        # get the optimized prefactor
        countprefact = defaultdict(int)
        nbplus, nbminus = 0, 0
        if constant not in [0, 1, -1]:
            countprefact[constant] += 1
            if constant.real + constant.imag > 0:
                nbplus += 1
            else:
                nbminus += 1

        for var in items.values():
            if var.prefactor == 0:
                self.remove(var)
            else:
                nb = var.prefactor
                if nb in [1, -1]:
                    continue
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1
        if countprefact and max(countprefact.values()) > 1:
            fact_prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
        else:
            fact_prefactor = 1
        if nbplus < nbminus:
            fact_prefactor *= -1
        self.prefactor *= fact_prefactor

        if fact_prefactor != 1:
            for i, a in enumerate(self):
                try:
                    a.prefactor /= fact_prefactor
                except AttributeError:
                    self[i] /= fact_prefactor

        if constant:
            self.append(constant / fact_prefactor)

        # deal with one/zero length object
        varlen = len(self)
        if varlen == 1:
            if hasattr(self[0], 'vartype'):
                return self.prefactor * self[0].simplify()
            #self[0] is a number
            return self.prefactor * self[0]
        elif varlen == 0:
            return 0 #ConstantObject()
        return self

    def split(self, variables_id):
        """return a dict with the key being the power associated to each variables
        and the value being the object remaining after the suppression of all
        the variable"""
        out = defaultdict(int)
        for obj in self:
            for key, value in obj.split(variables_id).items():
                out[key] += self.prefactor * value
        return out

    def contains(self, variables):
        """returns true if one of the variables is in the expression"""
        return any(v in obj for obj in self for v in variables)

    def get_all_var_names(self):
        """Concatenate get_all_var_names() of the terms that provide it."""
        out = []
        for term in self:
            if hasattr(term, 'get_all_var_names'):
                out += term.get_all_var_names()
        return out

    def replace(self, id, expression):
        """replace one object (identify by his id) by a given expression.
        Note that expression cann't be zero.
        Note that this should be canonical form (this should contains ONLY
        MULTVARIABLE) --so this should be called before a factorize.
        """
        new = self.__class__()
        for obj in self:
            assert isinstance(obj, MultVariable)
            tmp = obj.replace(id, expression)
            new += tmp
        new.prefactor = self.prefactor
        return new

    def expand(self, veto=[]):
        """Pass from High level object to low level object

        Items without an ``expand`` method (plain numbers, MultVariable) are
        used as they are, whatever their position and the prefactor.
        """
        if not self:
            return self
        prefactor = self.prefactor

        def scaled(obj):
            return obj if prefactor == 1 else prefactor * obj

        try:
            new = scaled(self[0].expand(veto))
        except AttributeError:
            new = scaled(self[0])
        for item in self[1:]:
            try:
                new += scaled(item.expand(veto))
            except AttributeError:
                new = new + scaled(item)
        return new

    def __mul__(self, obj):
        """define the multiplication of
            - a AddVariable with a number
            - a AddVariable with an AddVariable
        other type of multiplication are define via the symmetric operation base
        on the obj class."""
        if not hasattr(obj, 'vartype'): # obj is a number
            if not obj:
                return 0
            return self.__class__(self, self.prefactor * obj)
        elif obj.vartype == 1: # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i * j for i in self for j in obj]
            return new
        #force the program to look at obj * self
        return NotImplemented

    def __imul__(self, obj):
        """In-place version of __mul__ (a new object for AddVariable * AddVariable)."""
        if not hasattr(obj, 'vartype'): # obj is a number
            if not obj:
                return 0
            self.prefactor *= obj
            return self
        elif obj.vartype == 1: # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i * j for i in self for j in obj]
            return new
        #force the program to look at obj * self
        return NotImplemented

    def __neg__(self):
        """Return a negated shallow copy; the operand is left untouched."""
        return self.__class__(self, -self.prefactor)

    def __add__(self, obj):
        """Define all the different addition (returns a new object)."""
        if not hasattr(obj, 'vartype'):
            if not obj: # obj is zero
                return self
            new = self.__class__(self, self.prefactor)
            new.append(obj / self.prefactor)
            return new
        elif obj.vartype == 2: # obj is a MultVariable
            new = AddVariable(self, self.prefactor)
            if self.prefactor == 1:
                new.append(obj)
            else:
                new.append((1 / self.prefactor) * obj)
            return new
        elif obj.vartype == 1: # obj is a AddVariable
            new = AddVariable(self, self.prefactor)
            for item in obj:
                new.append(obj.prefactor / self.prefactor * item)
            return new
        #force to look at obj + self
        return NotImplemented

    def __iadd__(self, obj):
        """Define all the different addition (in place)."""
        if not hasattr(obj, 'vartype'):
            if not obj: # obj is zero
                return self
            self.append(obj / self.prefactor)
            return self
        elif obj.vartype == 2: # obj is a MultVariable
            if self.prefactor == 1:
                self.append(obj)
            else:
                self.append((1 / self.prefactor) * obj)
            return self
        elif obj.vartype == 1: # obj is a AddVariable
            for item in obj:
                self.append(obj.prefactor / self.prefactor * item)
            return self
        #force to look at obj + self
        return NotImplemented

    def __sub__(self, obj):
        return self + (-1) * obj

    def __rsub__(self, obj):
        return (-1) * self + obj

    __radd__ = __add__
    __rmul__ = __mul__

    def __div__(self, obj):
        return self.__mul__(1 / obj)

    __truediv__ = __div__

    def __rdiv__(self, obj):
        # Dividing by a sum cannot be expressed as an AddVariable; let Python
        # raise TypeError (the former code called a non-existent __rmult__).
        return NotImplemented

    def __str__(self):
        text = ''
        if self.prefactor != 1:
            text += str(self.prefactor) + ' * '
        text += '( '
        text += ' + '.join([str(item) for item in self])
        text += ' )'
        return text

    # list's repr would hide the prefactor (e.g. in str() of a list of terms)
    __repr__ = __str__

    def count_term(self):
        """Return (count, id) of the variable id present in most terms.

        Ties are broken by a correlation weight, then by the lowest string
        form. Returns (1, None) when no variable appears twice, and (0, None)
        when no term is a product of variables.
        """
        # Count the number of appearance of each variable and find the most
        #present one in order to factorize her
        count = defaultdict(int)
        correlation = defaultdict(defaultdict(int))
        for term in self:
            try:
                set_term = set(term)
            except TypeError:
                #constant term
                continue
            for val1 in set_term:
                count[val1] += 1
                # allow to find optimized factorization for identical count
                for val2 in set_term:
                    correlation[val1][val2] += 1

        maxnb = max(count.values()) if count else 0
        possibility = [v for v, val in count.items() if val == maxnb]
        if maxnb == 1:
            return 1, None
        elif len(possibility) == 1:
            return maxnb, possibility[0]

        max_wgt, maxvar, str_maxvar = 0, None, None
        for var in possibility:
            wgt = sum(w**2 for w in correlation[var].values()) / len(correlation[var])
            if wgt > max_wgt:
                maxvar = var
                max_wgt = wgt
                str_maxvar = str(KERNEL.objs[var])
            elif wgt == max_wgt:
                # keep the one with the lowest string expr
                new_str = str(KERNEL.objs[var])
                if new_str < str_maxvar:
                    maxvar = var
                    str_maxvar = new_str
        return maxnb, maxvar

    def factorize(self):
        """ try to factorize as much as possible the expression

        Note: the terms containing the factorized variable are modified in
        place (the variable is removed from them).
        """
        nb_max, maxvar = self.count_term()
        if nb_max <= 1:
            #no factorization possible
            return self
        # split in MAXVAR * NEWADD + CONSTANT
        newadd = AddVariable()
        constant = AddVariable()
        #fill NEWADD and CONSTANT
        for term in self:
            try:
                term.remove(maxvar)
            except Exception:
                constant.append(term)
            else:
                if len(term):
                    newadd.append(term)
                else:
                    newadd.append(term.prefactor)
        newadd = newadd.factorize()

        # optimize the prefactor
        if isinstance(newadd, AddVariable):
            countprefact = defaultdict(int)
            nbplus, nbminus = 0, 0
            for nb in [a.prefactor for a in newadd if hasattr(a, 'prefactor')]:
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1
            # nothing to extract when every term is a plain number
            if countprefact:
                newadd.prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
            if nbplus < nbminus:
                newadd.prefactor *= -1
            if newadd.prefactor != 1:
                for i, a in enumerate(newadd):
                    try:
                        a.prefactor /= newadd.prefactor
                    except AttributeError:
                        newadd[i] /= newadd.prefactor

        if len(constant) > 1:
            constant = constant.factorize()
        elif constant:
            constant = constant[0]
        else:
            out = MultContainer([KERNEL.objs[maxvar], newadd])
            out.prefactor = self.prefactor
            if newadd.prefactor != 1:
                out.prefactor *= newadd.prefactor
                newadd.prefactor = 1
            return out
        out = AddVariable([MultContainer([KERNEL.objs[maxvar], newadd]), constant],
                          self.prefactor)
        return out
class MultContainer(list):
    """Product of already-built objects, times ``prefactor``.

    Unlike MultVariable (an array of KERNEL ids), the elements here are the
    objects themselves, e.g. a variable and an AddVariable.
    """

    vartype = 6

    def __init__(self, *args):
        self.prefactor = 1
        list.__init__(self, *args)

    def __str__(self):
        """ String representation """
        body = ' * '.join([str(t) for t in self])
        if self.prefactor != 1:
            return '(%s * %s)' % (self.prefactor, body)
        return '(%s)' % body

    def factorize(self):
        """Factorize every factor in place and return self.

        Factors without a ``factorize`` method (plain variable names, numbers)
        are kept unchanged. Returning self makes ``x = x.factorize()`` safe,
        as it is for the other expression classes.
        """
        self[:] = [term.factorize() if hasattr(term, 'factorize') else term
                   for term in self]
        return self
class MultVariable(array):
    """ A list of Variable with multiplication as operator between themselves.
    Represented by array for speed optimization

    The array holds KERNEL ids (typecode 'i'); ``prefactor`` is the numerical
    coefficient of the product.
    """
    vartype=2
    addclass = AddVariable

    def __new__(cls, old=(), prefactor=1):
        return array.__new__(cls, 'i', old)

    def __init__(self, old=(), prefactor=1):
        """ initialization of the object with default value """
        #array.__init__(self, 'i', old) <- done already in new !!
        self.prefactor = prefactor
        # numbers.Number covers int/long/float/complex (and Fraction) on
        # both Python 2 and 3 (``long`` does not exist on Python 3).
        assert isinstance(self.prefactor, numbers.Number)

    def get_id(self):
        """Return the id of a single-variable product."""
        assert len(self) == 1
        return self[0]

    def sort(self):
        """Sort the ids in place and return self."""
        a = list(self)
        a.sort()
        self[:] = array('i', a)
        return self

    def simplify(self):
        """ simplify the product (an empty product is its prefactor)"""
        if not len(self):
            return self.prefactor
        return self

    def split(self, variables_id):
        """return a dict with the key being the power associated to each variables
        and the value being the object remaining after the suppression of all
        the variable (note: self is modified in place)"""
        key = tuple([self.count(i) for i in variables_id])
        arg = [id for id in self if id not in variables_id]
        self[:] = array('i', arg)
        return SplitCoefficient([(key, self)])

    def replace(self, id, expression):
        """replace one object (identify by his id) by a given expression.
        Note that expression cann't be zero. self is modified in place.
        """
        assert hasattr(expression, 'vartype') , 'expression should be of type Add or Mult'

        if expression.vartype == 1: # AddVariable
            nb = self.count(id)
            if not nb:
                return self
            for i in range(nb):
                self.remove(id)
            new = self
            for i in range(nb):
                new *= expression
            return new
        elif expression.vartype == 2: # MultLorentz
            # be carefull about A -> A * B: nb is counted before any change
            nb = self.count(id)
            for i in range(nb):
                self.remove(id)
                self.__imul__(expression)
            return self
        raise Exception('Cann\'t replace a Variable by %s' % type(expression))

    def get_all_var_names(self):
        """return the list of variable used in this multiplication"""
        return ['%s' % KERNEL.objs[n] for n in self]

    #Defining rule of Multiplication
    def __mul__(self, obj):
        """Define the multiplication with different object"""
        if not hasattr(obj, 'vartype'): # should be a number
            if obj:
                return self.__class__(self, obj * self.prefactor)
            return 0
        elif obj.vartype == 1: # obj is an AddVariable
            new = obj.__class__([], self.prefactor * obj.prefactor)
            old, self.prefactor = self.prefactor, 1
            new[:] = [self * term for term in obj]
            self.prefactor = old
            return new
        elif obj.vartype == 4:
            return NotImplemented
        return self.__class__(array.__add__(self, obj), self.prefactor * obj.prefactor)

    __rmul__ = __mul__

    def __imul__(self, obj):
        """Define the multiplication with different object"""
        if not hasattr(obj, 'vartype'): # should be a number
            if obj:
                self.prefactor *= obj
                return self
            return 0
        elif obj.vartype == 1: # obj is an AddVariable
            new = obj.__class__([], self.prefactor * obj.prefactor)
            self.prefactor = 1
            new[:] = [self * term for term in obj]
            return new
        elif obj.vartype == 4:
            return NotImplemented
        self.prefactor *= obj.prefactor
        return array.__iadd__(self, obj)

    def __pow__(self, value):
        out = 1
        for i in range(value):
            out *= self
        return out

    def __add__(self, obj):
        """ define the adition with different object"""
        if not obj:
            return self
        elif not hasattr(obj, 'vartype') or obj.vartype == 2:
            new = self.addclass([self, obj])
            return new
        #call the implementation of addition implemented in obj
        return NotImplemented

    __radd__ = __add__
    __iadd__ = __add__

    def __sub__(self, obj):
        return self + (-1) * obj

    def __neg__(self):
        """Return a negated copy; the operand is left untouched."""
        return self.__class__(self, -self.prefactor)

    def __rsub__(self, obj):
        return (-1) * self + obj

    def __idiv__(self, obj):
        """ ONLY NUMBER DIVISION ALLOWED (in place)"""
        assert not hasattr(obj, 'vartype')
        self.prefactor /= obj
        return self

    __itruediv__ = __idiv__

    def __div__(self, obj):
        """ ONLY NUMBER DIVISION ALLOWED (returns a new object)"""
        assert not hasattr(obj, 'vartype')
        return self.__class__(self, self.prefactor / obj)

    __truediv__ = __div__

    def __str__(self):
        """ String representation """
        t = ['%s' % KERNEL.objs[n] for n in self]
        if self.prefactor != 1:
            text = '(%s * %s)' % (self.prefactor, ' * '.join(t))
        else:
            text = '(%s)' % (' * '.join(t))
        return text

    # __rep__ (a typo) is kept for backward compatibility; __repr__ makes the
    # prefactor visible when the object appears inside a list's str().
    __repr__ = __rep__ = __str__

    def factorize(self):
        return self
#===============================================================================
# FactoryVar
#===============================================================================
class C_Variable(str):
    """Variable name (a ``str``) tagged with the type 'complex'."""
    type = 'complex'
    vartype = 0
class R_Variable(str):
    """Variable name (a ``str``) tagged with the type 'double'."""
    type = 'double'
    vartype = 0
class ExtVariable(str):
    """Variable name (a ``str``) tagged with the type 'parameter'."""
    type = 'parameter'
    vartype = 0
class FactoryVar(object):
    """This is the standard object for all the variable linked to expression.

    Instantiation never yields a FactoryVar: it returns a ``mult_class``
    instance holding the KERNEL id of the variable called ``name``. The
    underlying object is built as ``baseclass(name, *args)`` and registered
    only the first time the name is seen.
    """
    mult_class = MultVariable # The class for the multiplication

    def __new__(cls, name, baseclass, *args):
        """Factory class return a MultVariable."""
        if name not in KERNEL:
            created = baseclass(name, *args)
            # KERNEL.add stores the id under ``name`` and returns it
            created.id = KERNEL.add(name, created)
        return cls.mult_class([KERNEL[name]])
class Variable(FactoryVar):
    """Factory for scalar variables, complex (C_Variable) by default.

    ``Variable(name)`` returns the single-element MultVariable built by
    FactoryVar; pass another class as ``type`` to change the stored object.
    """

    def __new__(cls, name, type=C_Variable):
        return FactoryVar(name, type)
class DVariable(FactoryVar):
    """Factory for real (R_Variable) variables, promoted to complex
    (C_Variable) in the cases selected by the aloha mode flags."""

    def __new__(cls, name):
        # complex-mass scheme: masses/widths (M*, W*, OM*) become complex;
        # loop mode: momenta (P*) become complex.
        use_complex = (
            (aloha.complex_mass and (name[0] in ['M', 'W'] or name.startswith('OM')))
            or (aloha.loop_mode and name.startswith('P')))
        return FactoryVar(name, C_Variable if use_complex else R_Variable)
#===============================================================================
# Object for Analytical Representation of Lorentz object (not scalar one)
#===============================================================================


#===============================================================================
# MultLorentz
#===============================================================================
class MultLorentz(MultVariable):
    """Specific class for LorentzObject Multiplication

    Like MultVariable, the array stores KERNEL ids; the corresponding
    objects are looked up in KERNEL.objs.
    """

    add_class = AddVariable # Define which class describe the addition

    def find_lorentzcontraction(self):
        """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
        the contraction in this Multiplication."""
        return self._find_contraction('lorentz_ind')

    def find_spincontraction(self):
        """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
        the contraction in this Multiplication."""
        return self._find_contraction('spin_ind')

    def _find_contraction(self, attr):
        """Map (factor_pos, index_pos) pairs sharing an index of kind attr.

        The elements of self are integer ids, so the objects carrying the
        index lists are fetched from KERNEL.objs.
        """
        out = {}
        factors = [getattr(KERNEL.objs[var_id], attr) for var_id in self]
        for i, indices in enumerate(factors):
            for j, label in enumerate(indices):
                # compare with the later elements of the multiplication
                for k in range(i + 1, len(factors)):
                    try:
                        l = factors[k].index(label)
                    except ValueError:
                        continue
                    out[(i, j)] = (k, l)
                    out[(k, l)] = (i, j)
        return out

    def neighboor(self, home):
        """return one variable which are contracted with var and not yet expanded

        Scans self.unused (set by expand) and returns the first object that
        shares a Lorentz or spin index with home, or None.
        """
        for var in self.unused:
            obj = KERNEL.objs[var]
            if obj.has_component(home.lorentz_ind, home.spin_ind):
                return obj
        return None

    def expand(self, veto=[]):
        """ expand each part of the product and combine them.
            Try to use a smart order in order to minimize the number of uncontracted indices.
            Veto forbids the use of sub-expression if it contains some of the variable in the
            expression. Veto contains the id of the vetoed variables
        """
        self.unused = self[:] # list of not expanded
        # made in a list the interesting starting point for the computation
        basic_end_point = [var for var in self if KERNEL.objs[var].contract_first]
        product_term = [] #store result of intermediate chains
        current = None # current point in the working chain (a KERNEL id)

        while self.unused:
            #Loop untill we have expand everything
            # ``is None``: 0 is a valid KERNEL id and must not end the chain
            if current is None:
                # First we need to have a starting point
                try:
                    # look in priority in basic_end_point (P/S/fermion/...)
                    current = basic_end_point.pop()
                except IndexError:
                    #take one of the remaining
                    current = self.unused.pop()
                else:
                    #check that this one is not already use
                    if current not in self.unused:
                        current = None
                        continue
                    #remove of the unuse (usualy done in the pop)
                    self.unused.remove(current)
                # initialize the new chain
                product_term.append(KERNEL.objs[current].expand())

            # We have a point -> find one term which is contracted with the
            # chain and which is not yet expanded.
            var_obj = self.neighboor(product_term[-1])
            if var_obj:
                product_term[-1] *= var_obj.expand()
                self.unused.remove(var_obj.id)
                continue

            current = None

        # Multiply all those current
        # For Fermion/Vector only one can carry index.
        out = self.prefactor
        for fact in product_term:
            if hasattr(fact, 'vartype') and fact.lorentz_ind == fact.spin_ind == []:
                scalar = fact.get_rep([0])
                if hasattr(scalar, 'vartype') and scalar.vartype == 1:
                    if not veto or not scalar.contains(veto):
                        scalar = scalar.simplify()
                        prefactor = 1

                        if hasattr(scalar, 'vartype') and scalar.prefactor not in [1, -1]:
                            prefactor = scalar.prefactor
                            scalar.prefactor = 1
                        new = KERNEL.add_expression_contraction(scalar)
                        fact.set_rep([0], prefactor * new)
            out *= fact
        return out

    def __copy__(self):
        """ create a shadow copy """
        new = MultLorentz(self)
        new.prefactor = self.prefactor
        return new
960 #=============================================================================== 961 # LorentzObject 962 #=============================================================================== 963 -class LorentzObject(object):
964 """ A symbolic Object for All Helas object. All Helas Object Should 965 derivated from this class""" 966 967 contract_first = 0 968 mult_class = MultLorentz # The class for the multiplication 969 add_class = AddVariable # The class for the addition 970
971 - def __init__(self, name, lor_ind, spin_ind, tags=[]):
972 """ initialization of the object with default value """ 973 assert isinstance(lor_ind, list) 974 assert isinstance(spin_ind, list) 975 976 self.name = name 977 self.lorentz_ind = lor_ind 978 self.spin_ind = spin_ind 979 KERNEL.add_tag(set(tags))
980
981 - def expand(self):
982 """Expand the content information into LorentzObjectRepresentation.""" 983 984 try: 985 return self.representation 986 except Exception: 987 self.create_representation() 988 return self.representation
989
990 - def create_representation(self):
991 raise self.VariableError("This Object %s doesn't have define representation" % self.__class__.__name__)
992
993 - def has_component(self, lor_list, spin_list):
994 """check if this Lorentz Object have some of those indices""" 995 996 if any([id in self.lorentz_ind for id in lor_list]) or \ 997 any([id in self.spin_ind for id in spin_list]): 998 return True
999 1000 1001
1002 - def __str__(self):
1003 return '%s' % self.name
1004
class FactoryLorentz(FactoryVar):
    """Factory for Lorentz objects.

    Calling the class builds a name from the class name and the arguments,
    creates ``object_class(name, *args)`` the first time that name is seen,
    and returns a MultLorentz holding its KERNEL id.
    """

    mult_class = MultLorentz # The class for the multiplication
    object_class = LorentzObject # Define How to create the basic object.

    def __new__(cls, *args):
        unique_name = cls.get_unique_name(*args)
        return FactoryVar.__new__(cls, unique_name, cls.object_class, *args)

    @classmethod
    def get_unique_name(cls, *args):
        """default way to have a unique name: '_L_<class>_<args joined by _>'"""
        return '_L_%s_%s' % (cls.__name__, '_'.join(args))
#===============================================================================
# LorentzObjectRepresentation
#===============================================================================
class LorentzObjectRepresentation(dict):
    """A concrete representation of the LorentzObject.

    The object is a dict whose keys are index tuples -- the Lorentz indices
    first (in the order of ``lorentz_ind``) followed by the spin indices (in
    the order of ``spin_ind``) -- and whose values are the corresponding
    components.  An object without any index (a scalar) stores its single
    value under the key ``(0,)``.
    """

    vartype = 4 # Optimization for instance recognition

    class LorentzObjectRepresentationError(Exception):
        """Specify error for LorentzObjectRepresentation"""

    def __init__(self, representation, lorentz_indices, spin_indices):
        """Initialize the Lorentz object representation.

        representation: for an object with indices, anything accepted by
            dict() (index tuple -> value).  For a scalar (no index at all),
            either a dict that is empty or holds only the key (0,), or the
            scalar value itself.
        lorentz_indices / spin_indices: names of the indices.

        Raises LorentzObjectRepresentationError when a scalar is given a
        non-empty dict that is not of the form {(0,): value}.
        """

        self.lorentz_ind = lorentz_indices #lorentz indices
        self.nb_lor = len(lorentz_indices) #their number
        self.spin_ind = spin_indices #spin indices
        self.nb_spin = len(spin_indices) #their number
        self.nb_ind = self.nb_lor + self.nb_spin #total number of indices

        #store the representation
        if self.lorentz_ind or self.spin_ind:
            dict.__init__(self, representation)
        elif isinstance(representation, dict):
            # scalar given as a dict: only {} or {(0,): value} is meaningful
            if len(representation) == 0:
                self[(0,)] = 0
            elif len(representation) == 1 and (0,) in representation:
                self[(0,)] = representation[(0,)]
            else:
                raise self.LorentzObjectRepresentationError("There is no key of (0,) in representation.")
        else:
            # scalar given directly as its value (dicts were handled above)
            self[(0,)] = representation

    def __str__(self):
        """ string representation """
        text = 'lorentz index :' + str(self.lorentz_ind) + '\n'
        text += 'spin index :' + str(self.spin_ind) + '\n'
        for ind in self.listindices():
            ind = tuple(ind)
            text += str(ind) + ' --> '
            text += str(self.get_rep(ind)) + '\n'
        return text

    def get_rep(self, indices):
        """return the value/Variable associate to the indices"""
        return self[tuple(indices)]

    def set_rep(self, indices, value):
        """assign 'value' at the indices position"""

        self[tuple(indices)] = value

    def listindices(self):
        """Return an iterator in order to be able to loop easily on all the
        indices of the object."""
        return IndicesIterator(self.nb_ind)

    @staticmethod
    def get_mapping(l1, l2, switch_order=None):
        """Append to switch_order the position in l2 of every element of l1.

        Positions are shifted by the length switch_order had on entry, so two
        consecutive calls (Lorentz indices, then spin indices) build a single
        mapping over the concatenated index list.  Returns switch_order.

        Raises LorentzObjectRepresentationError if an element of l1 is not
        in l2.
        """
        # A shared mutable default would accumulate entries across calls
        # made without an explicit list; start from a fresh one instead.
        if switch_order is None:
            switch_order = []
        shift = len(switch_order)
        for value in l1:
            try:
                index = l2.index(value)
            except ValueError:
                raise LorentzObjectRepresentation.LorentzObjectRepresentationError(
                "Invalid addition. Object doen't have the same lorentz "+ \
                "indices : %s != %s" % (l1, l2))
            else:
                switch_order.append(shift + index)
        return switch_order

    def __add__(self, obj, fact=1):
        """Return a new object equal to self + fact * obj.

        obj is either a LorentzObjectRepresentation carrying the same indices
        as self (possibly in another order; the result then uses obj's index
        order), or a plain value when self is a scalar.  A falsy obj returns
        self unchanged.
        """

        if not obj:
            return self

        if not hasattr(obj, 'vartype'):
            assert self.lorentz_ind == []
            assert self.spin_ind == []
            new = self[(0,)] + obj * fact
            out = LorentzObjectRepresentation(new, [], [])
            return out

        assert(obj.vartype == 4 == self.vartype) # are LorentzObjectRepresentation

        if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:
            # if the order of indices are different compute a mapping.
            # switch turns an index tuple written in obj's order into the
            # same component written in self's order.
            switch_order = []
            self.get_mapping(self.lorentz_ind, obj.lorentz_ind, switch_order)
            self.get_mapping(self.spin_ind, obj.spin_ind, switch_order)
            switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))])
        else:
            # no mapping needed (define switch as identity)
            switch = lambda ind : (ind)

        # Some sanity check
        assert tuple(self.lorentz_ind+self.spin_ind) == tuple(switch(obj.lorentz_ind+obj.spin_ind)), \
            '%s!=%s' % (self.lorentz_ind+self.spin_ind, switch(obj.lorentz_ind+obj.spin_ind))
        assert tuple(self.lorentz_ind) == tuple(switch(obj.lorentz_ind)), \
            '%s!=%s' % (tuple(self.lorentz_ind), switch(obj.lorentz_ind))

        # define an empty representation (uses obj's index order)
        new = LorentzObjectRepresentation({}, obj.lorentz_ind, obj.spin_ind)

        # loop over all indices (read in obj's order) and fill the new object:
        # obj is read directly, self through switch, in both branches.
        if fact == 1:
            for ind in self.listindices():
                value = obj.get_rep(ind) + self.get_rep(switch(ind))
                new.set_rep(ind, value)
        else:
            for ind in self.listindices():
                value = fact * obj.get_rep(ind) + self.get_rep(switch(ind))
                new.set_rep(ind, value)

        return new

    def __iadd__(self, obj, fact=1):
        """In-place addition self += fact * obj; returns self.

        obj must carry the same indices as self, possibly in another order.
        A falsy obj leaves self unchanged.  A plain value (no vartype) is
        delegated to __add__, which accepts it for scalar objects only.
        """

        if not obj:
            return self

        if not hasattr(obj, 'vartype'):
            return self.__add__(obj, fact)

        assert(obj.vartype == 4 == self.vartype) # are LorentzObjectRepresentation

        if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:

            # if the order of indices are different compute a mapping.
            # switch turns an index tuple in self's order into obj's order.
            switch_order = []
            self.get_mapping(obj.lorentz_ind, self.lorentz_ind, switch_order)
            self.get_mapping(obj.spin_ind, self.spin_ind, switch_order)
            switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))])
        else:
            # no mapping needed (define switch as identity)
            switch = lambda ind : (ind)

        # Some sanity check
        assert tuple(switch(self.lorentz_ind+self.spin_ind)) == tuple(obj.lorentz_ind+obj.spin_ind), \
            '%s!=%s' % (switch(self.lorentz_ind+self.spin_ind), (obj.lorentz_ind+obj.spin_ind))
        assert tuple(switch(self.lorentz_ind)) == tuple(obj.lorentz_ind), \
            '%s!=%s' % (switch(self.lorentz_ind), tuple(obj.lorentz_ind))

        # loop over all indices and update self
        if fact == 1:
            for ind in self.listindices():
                self[tuple(ind)] += obj.get_rep(switch(ind))
        else:
            for ind in self.listindices():
                self[tuple(ind)] += fact * obj.get_rep(switch(ind))
        return self

    def __sub__(self, obj):
        """Return a new object equal to self - obj."""
        return self.__add__(obj, fact= -1)

    def __rsub__(self, obj):
        """Return obj - self (called when obj cannot subtract self itself)."""
        if hasattr(obj, 'vartype'):
            return obj.__add__(self, fact= -1)
        # plain numbers have no 'fact' keyword on __add__: compute -self + obj
        return (self * -1).__add__(obj)

    def __isub__(self, obj):
        """self -= obj; returns a new object (self is not modified)."""
        return self.__add__(obj, fact= -1)

    def __neg__(self):
        """Return -self.  `self *= -1` rebinds the local name through
        __imul__ (== __mul__), so the original object is not modified."""
        self *= -1
        return self

    def __mul__(self, obj):
        """multiplication performing directly the einstein/spin sommation.

        A plain value multiplies every component.  For two representations,
        indices present in both are summed over (contracted); if none are
        shared, the tensor product is returned.
        """

        if not hasattr(obj, 'vartype'):
            out = LorentzObjectRepresentation({}, self.lorentz_ind, self.spin_ind)
            for ind in out.listindices():
                out.set_rep(ind, obj * self.get_rep(ind))
            return out

        # Sanity Check
        assert(obj.__class__ == LorentzObjectRepresentation), \
                            '%s is not valid class for this operation' %type(obj)

        # compute information on the status of the index (which are contracted/
        #not contracted
        l_ind, sum_l_ind = self.compare_indices(self.lorentz_ind, \
                                                            obj.lorentz_ind)
        s_ind, sum_s_ind = self.compare_indices(self.spin_ind, \
                                                            obj.spin_ind)
        if not(sum_l_ind or sum_s_ind):
            # No contraction made a tensor product
            return self.tensor_product(obj)

        # otherwise perform the contraction
        # create an empty representation but with correct indices
        new_object = LorentzObjectRepresentation({}, l_ind, s_ind)
        #loop and fill the representation
        for indices in new_object.listindices():
            #made a dictionary (pos -> index_value) for how call the object
            dict_l_ind = self.pass_ind_in_dict(indices[:len(l_ind)], l_ind)
            dict_s_ind = self.pass_ind_in_dict(indices[len(l_ind):], s_ind)
            #add the new value
            new_object.set_rep(indices, \
                               self.contraction(obj, sum_l_ind, sum_s_ind, \
                                                 dict_l_ind, dict_s_ind))

        return new_object

    __rmul__ = __mul__
    __imul__ = __mul__

    def contraction(self, obj, l_sum, s_sum, l_dict, s_dict):
        """ make the Lorentz/spin contraction of object self and obj.
        l_sum/s_sum are the position of the sum indices
        l_dict/s_dict are dict given the value of the fix indices (indices->value)

        l_dict and s_dict are updated in place with the values of the summed
        indices.  Each term gets a factor (-1) for every summed Lorentz index
        whose value is not 0 (presumably the (+,-,-,-) metric).
        """
        out = 0 # initial value for the output
        len_l = len(l_sum) #store len for optimization
        len_s = len(s_sum) # same

        # loop over the possibility for the sum indices and update the dictionary
        # (indices->value)
        for l_value in IndicesIterator(len_l):
            l_dict.update(self.pass_ind_in_dict(l_value, l_sum))
            # the sign depends only on the Lorentz values: hoist it
            sign = (-1) ** (len(l_value) - l_value.count(0))
            for s_value in IndicesIterator(len_s):
                s_dict.update(self.pass_ind_in_dict(s_value, s_sum))

                #return the indices in the correct order
                self_ind = self.combine_indices(l_dict, s_dict)
                obj_ind = obj.combine_indices(l_dict, s_dict)

                # call the object
                factor = obj.get_rep(obj_ind) * self.get_rep(self_ind)

                if factor:
                    #compute the prefactor due to the lorentz contraction
                    try:
                        factor.prefactor *= sign
                    except Exception:
                        factor *= sign
                    out += factor
        return out

    def tensor_product(self, obj):
        """ return the tensorial product of the object"""
        assert(obj.vartype == 4) #isinstance(obj, LorentzObjectRepresentation))

        new_object = LorentzObjectRepresentation({}, \
                                           self.lorentz_ind + obj.lorentz_ind, \
                                           self.spin_ind + obj.spin_ind)

        #some shortcut
        lor1 = self.nb_lor
        lor2 = obj.nb_lor
        spin1 = self.nb_spin
        spin2 = obj.nb_spin

        # new indices are ordered: lor(self), lor(obj), spin(self), spin(obj)
        #define how to call build the indices first for the first object
        if lor1 == 0 == spin1:
            #special case for scalar
            selfind = lambda indices: [0]
        else:
            selfind = lambda indices: indices[:lor1] + \
                                        indices[lor1 + lor2: lor1 + lor2 + spin1]

        #then for the second
        if lor2 == 0 == spin2:
            #special case for scalar
            objind = lambda indices: [0]
        else:
            objind = lambda indices: indices[lor1: lor1 + lor2] + \
                                        indices[lor1 + lor2 + spin1:]

        # loop on the indices and assign the product
        for indices in new_object.listindices():

            fac1 = self.get_rep(tuple(selfind(indices)))
            fac2 = obj.get_rep(tuple(objind(indices)))
            new_object.set_rep(indices, fac1 * fac2)

        return new_object

    def factorize(self):
        """Try to factorize each component in place; return self.

        Components without a factorize() method (e.g. plain numbers) are
        left untouched instead of raising AttributeError.
        """
        for ind, fact in list(self.items()):
            if fact and hasattr(fact, 'factorize'):
                self.set_rep(ind, fact.factorize())

        return self

    def simplify(self):
        """Check if we can simplify the object (check for non treated Sum)"""

        #Look for internal simplification
        for ind, term in self.items():
            if hasattr(term, 'vartype'):
                self[ind] = term.simplify()
        #no additional simplification
        return self

    @staticmethod
    def compare_indices(list1, list2):
        """return two list, the first one contains the position of non summed
        index and the second one the position of summed index.

        Order is preserved: unique indices of list1 come first, then those of
        list2; summed indices follow their order in list1.
        """

        are_unique, are_sum = [], []
        # loop over the first list and check if they are in the second list
        for indice in list1:
            if indice in list2:
                are_sum.append(indice)
            else:
                are_unique.append(indice)

        # loop over the second list for additional unique item
        for indice in list2:
            if indice not in are_sum:
                are_unique.append(indice)

        # return value
        return are_unique, are_sum

    @staticmethod
    def pass_ind_in_dict(indices, key):
        """made a dictionary (pos -> index_value) for how call the object"""
        if not key:
            return {}
        out = {}
        for i, ind in enumerate(indices):
            out[key[i]] = ind
        return out

    def combine_indices(self, l_dict, s_dict):
        """return the indices in the correct order following the dicts rules"""

        out = []
        # First for the Lorentz indices
        for value in self.lorentz_ind:
            out.append(l_dict[value])
        # Same for the spin
        for value in self.spin_ind:
            out.append(s_dict[value])

        return out

    def split(self, variables_id):
        """return a dict with the key being the power associated to each variables
        and the value being the object remaining after the suppression of all
        the variable"""

        out = SplitCoefficient()
        zero_rep = {}
        for ind in self.listindices():
            zero_rep[tuple(ind)] = 0

        for ind in self.listindices():
            for key, value in self.get_rep(ind).split(variables_id).items():
                if key not in out:
                    # each power gets its own zero-filled copy of the indices
                    out[key] = LorentzObjectRepresentation(dict(zero_rep),
                                        self.lorentz_ind, self.spin_ind)
                out[key][tuple(ind)] += value

        return out

#===============================================================================
# IndicesIterator
#===============================================================================
class IndicesIterator:
    """Iterator over every set of values of ``len`` indices.

    Each index takes a value between 0 and 3, the first index varying
    fastest.  For speed, the SAME list object is returned at every step (it
    is updated in place): copy it if the values must be kept.  For a scalar
    object (``len == 0``) a single ``[0]`` is produced.
    """

    def __init__(self, len):
        """ create an iterator looping over the indices of a list of len "len"
        with each value can take value between 0 and 3 """

        self.len = len # number of indices
        if len:
            # initialize the position. The first position is -1 due to the method
            #in place which start by rising an index before returning smtg
            self.data = [-1] + [0] * (len - 1)
        else:
            # Special case for Scalar object: per-instance override of next
            self.data = 0
            self.next = self.nextscalar

    def __iter__(self):
        return self

    def next(self):
        """Advance to the next set of index values (Python 2 protocol)."""
        for i in range(self.len):
            if self.data[i] < 3:
                self.data[i] += 1
                return self.data
            else:
                self.data[i] = 0
        raise StopIteration

    def nextscalar(self):
        """Return [0] once for a scalar object, then stop."""
        if self.data:
            raise StopIteration
        else:
            self.data = True
            return [0]

    def __next__(self):
        """Python 3 iterator protocol.

        Python 3 looks __next__ up on the class, so delegate to ``next``,
        which honours the per-instance scalar override set in __init__.
        """
        return self.next()

class SplitCoefficient(dict):
    """Mapping from power tuples to the remaining coefficients.

    Behaves like a plain dict and additionally carries an empty ``tag`` set.
    """

    def __init__(self, *args, **opt):
        super(SplitCoefficient, self).__init__(*args, **opt)
        self.tag = set()

    def get_max_rank(self):
        """return the highest rank of the coefficient

        The rank of a key is the largest of its first four entries; raises
        ValueError when the mapping is empty.
        """
        return max(max(power[:4]) for power in self)

if __name__ == '__main__':
    # Micro-benchmark of compare_indices, only run when executed as a script.
    import cProfile

    def create():
        """Call compare_indices 10000 times with a varying first list."""
        for step in range(10000):
            LorentzObjectRepresentation.compare_indices(range(step % 10), [4, 3, 5])

    cProfile.run('create()')