Package aloha :: Module aloha_lib
[hide private]
[frames] | [no frames]

Source Code for Module aloha.aloha_lib

   1  ################################################################################ 
   2  # 
   3  # Copyright (c) 2010 The MadGraph5_aMC@NLO Development team and Contributors 
   4  # 
   5  # This file is a part of the MadGraph5_aMC@NLO project, an application which  
   6  # automatically generates Feynman diagrams and matrix elements for arbitrary 
   7  # high-energy processes in the Standard Model and beyond. 
   8  # 
   9  # It is subject to the MadGraph5_aMC@NLO license which should accompany this  
  10  # distribution. 
  11  # 
  12  # For more information, visit madgraph.phys.ucl.ac.be and amcatnlo.web.cern.ch 
  13  # 
  14  ################################################################################ 
  15  ##   Diagram of Class 
  16  ## 
  17  ##    Variable (vartype:0)<--- ScalarVariable  
  18  ##                          | 
  19  ##                          +- LorentzObject  
  20  ##                                 
  21  ## 
  22  ##    list <--- AddVariable (vartype :1)    
  23  ##            
  24  ##    array <--- MultVariable  <--- MultLorentz (vartype:2)  
  25  ##            
  26  ##    list <--- LorentzObjectRepresentation (vartype :4) <-- ConstantObject 
  27  ##                                                               (vartype:5) 
  28  ## 
  29  ##    FracVariable (vartype:3) 
  30  ## 
  31  ##    MultContainer (vartype:6) 
  32  ## 
  33  ################################################################################ 
  34  ## 
   35  ##   Variable is in fact a factory which adds a reference to the variable name
   36  ##   into the KERNEL (of class Computation), instantiates the underlying variable
   37  ##   object (of class C_Variable or R_Variable, for complex/real) and returns a
   38  ##   MultVariable with a single element.
  39  ## 
  40  ##   Lorentz Object works in the same way. 
  41  ## 
  42  ################################################################################ 
  43   
  44   
  45  from __future__ import division 
  46  from array import array 
  47  import collections 
  48  from fractions import Fraction 
  49  import numbers 
  50  import re 
  51  import aloha # define mode of writting 
class defaultdict(collections.defaultdict):
    """collections.defaultdict whose instances can also serve as factories.

    Calling an instance ignores its arguments and hands back a brand new
    ``defaultdict(int)``.  Nested counters can therefore be declared as
    ``defaultdict(defaultdict(int))``: the inner instance is the default
    factory of the outer one.
    """

    def __call__(self, *args):
        # the positional arguments are accepted for call-compatibility only
        counter = defaultdict(int)
        return counter
57
class Computation(dict):
    """Registry shared by the symbolic objects of this module (the KERNEL).

    As a dict it maps each variable name to an integer id; ``objs[id]`` is
    the corresponding variable object.  It also records the sub-expressions
    (``TMPn``) and the function calls (``FCTn``) extracted during a
    computation, together with the set of such tags that are in use.
    """

    # Functions whose call on purely numerical arguments is evaluated at once
    # with the cmath module (when cmath provides a function of that name).
    known_fct = ['/', 'log', 'pow', 'sin', 'cos', 'asin', 'acos', 'tan', 'cot', 'acot',
                 'theta_function', 'exp']

    def __init__(self):
        self.objs = []             # id -> variable object
        self.use_tag = set()       # TMP/FCT tags used by the current computation
        self.id = -1               # id of the last object added
        self.reduced_expr = {}     # str(expression) -> [Variable, 'TMPn']
        self.fct_expr = {}         # 'FCTn' -> (function name, arguments)
        self.reduced_expr2 = {}    # 'TMPn' / 'FCTn' -> reduced expression
        self.inverted_fct = {}     # call signature -> 'FCT(n)'
        self.has_pi = False # logical to check if pi is used in at least one fct
        self.unknow_fct = []       # (name, nb_args) of functions not known
        dict.__init__(self)

    def clean(self):
        """Reset the registry to its initial, empty state."""
        self.__init__()
        self.clear()

    def add(self, name, obj):
        """Register ``obj`` under ``name`` and return its new id."""
        self.id += 1
        self.objs.append(obj)
        self[name] = self.id
        return self.id

    def get(self, name):
        """Return the object registered under ``name`` (KeyError if unknown)."""
        return self.objs[self[name]]

    def add_tag(self, tag):
        """Mark every element of the iterable ``tag`` as a used tag."""
        self.use_tag.update(tag)

    def get_ids(self, variables):
        """Return the ids associated to the given variable names.

        A name that is not registered yet is created as a (complex) Variable.
        """
        out = []
        for var in variables:
            try:
                var_id = self[var]
            except KeyError:
                assert var not in ['M', 'W']
                var_id = Variable(var).get_id()
            out.append(var_id)
        return out

    def add_expression_contraction(self, expression):
        """Replace a scalar expression by a ``TMPn`` variable.

        Expressions with the same string share the same variable.  The
        simplified (and, when symbolic, factorized) expression is stored in
        ``reduced_expr2['TMPn']`` and the tag is marked as used.  Returns 0
        when the expression is, or simplifies to, zero.
        """
        str_expr = str(expression)
        if str_expr in self.reduced_expr:
            out, tag = self.reduced_expr[str_expr]
            self.add_tag((tag,))
            return out
        if expression == 0:
            return 0
        new_2 = expression.simplify()
        if new_2 == 0:
            return 0
        # Add a new variable
        tag = 'TMP%s' % len(self.reduced_expr)
        new = Variable(tag)
        self.reduced_expr[str_expr] = [new, tag]
        # simplify() may return a plain (non-zero) number: nothing to factorize
        if not isinstance(new_2, numbers.Number):
            new_2 = new_2.factorize()
        self.reduced_expr2[tag] = new_2
        self.add_tag((tag,))
        return new

    def add_function_expression(self, fct_tag, *args):
        """Register the call ``fct_tag(*args)`` and return its symbolic name.

        Lorentz-type arguments are expanded and must be scalar; every such
        argument is simplified and, when not a plain number, factorized.
        If the function is known and every argument is a number, the value is
        computed with cmath and returned as a string.  Otherwise the call is
        stored as ``FCTn`` (an identical earlier call is re-used) and the
        string ``'FCT(n)'`` is returned.

        Raises aloha.ALOHAERROR if a Lorentz argument is not a scalar.
        """
        if not (fct_tag.startswith('cmath.') or fct_tag in self.known_fct or
                (fct_tag, len(args)) in self.unknow_fct):
            self.unknow_fct.append((fct_tag, len(args)))

        argument = []
        for expression in args:
            if isinstance(expression, (MultLorentz, AddVariable, LorentzObject)):
                try:
                    expr = expression.expand().get_rep([0])
                except KeyError as error:
                    if error.args != ((0,),):
                        raise
                    raise aloha.ALOHAERROR(
                        'Error in input format.\n'
                        'Argument of function (or denominator) should be scalar.\n'
                        'We found %s' % expression)
                new = expr.simplify()
                if not isinstance(new, numbers.Number):
                    new = new.factorize()
                argument.append(new)
            else:
                argument.append(expression)

        # keep track of the FCT objects used inside the arguments
        for arg in argument:
            for v in re.findall(r'''\bFCT(\d*)\b''', str(arg)):
                self.add_tag(('FCT%s' % v,))

        # check if the function is a pure numerical function.
        if (fct_tag.startswith('cmath.') or fct_tag in self.known_fct) and \
                all(isinstance(x, numbers.Number) for x in argument):
            import cmath
            if fct_tag.startswith('cmath.'):
                fct_name = fct_tag[len('cmath.'):]
            else:
                fct_name = fct_tag
            try:
                return str(getattr(cmath, fct_name)(*argument))
            except Exception as error:
                # Best effort only: '/', 'pow', 'cot', ... have no cmath
                # counterpart (and domain errors are possible); fall back to
                # the symbolic FCT entry below.
                print(error)
                print('cmath.%s(%s)' % (fct_name, ','.join(repr(x) for x in argument)))

        # Signature of the call.  str() is used for symbolic arguments since
        # their repr (list/array repr) does not include the prefactors, which
        # made e.g. cos(2*x) and cos(3*x) share the same FCT entry.
        key = '%s(%s)' % (fct_tag, ', '.join(
            repr(a) if isinstance(a, numbers.Number) else str(a) for a in argument))
        if key in self.inverted_fct:
            tag = self.inverted_fct[key]
            self.add_tag(('FCT%s' % tag.split('(')[1][:-1],))
            return tag

        fct_id = len(self.fct_expr)
        tag = 'FCT%s' % fct_id
        self.inverted_fct[key] = 'FCT(%s)' % fct_id
        self.fct_expr[tag] = (fct_tag, argument)
        self.reduced_expr2[tag] = (fct_tag, argument)
        self.add_tag((tag,))
        return 'FCT(%s)' % fct_id
# Module-level Computation instance shared by every symbolic object of this
# module: variables are registered in it by name and looked up by id
# (KERNEL.objs[id]).
KERNEL = Computation()
#===============================================================================
# AddVariable
#===============================================================================
class AddVariable(list):
    """Sum of terms, representing ``prefactor * (term_1 + term_2 + ...)``.

    Terms are usually MultVariable objects or plain numbers; other symbolic
    objects (e.g. MultContainer) can appear after factorize().
    """

    #variable to fastenize class recognition
    vartype = 1

    def __init__(self, old_data=(), prefactor=1):
        """Build the sum from the terms of ``old_data`` (copied into self)."""
        self.prefactor = prefactor
        list.__init__(self, old_data)

    def simplify(self):
        """Simplify the sum (in place) and return the simplest equivalent object.

        Numerical terms are gathered into one constant, terms with the same
        variables are merged (a merged prefactor cancelling up to a relative
        1e-8 is set to zero) and vanishing terms are dropped.  If several
        terms share the same absolute prefactor, the most frequent one is
        moved to ``self.prefactor`` with the sign making most terms positive.

        Returns 0 if nothing remains, ``prefactor * term`` if a single term
        remains, and self otherwise.  Note that term.sort() sorts the ids of
        the original terms in place.
        """
        # deal with one length object (which may be a plain number)
        if len(self) == 1:
            term = self[0]
            if hasattr(term, 'vartype'):
                term = term.simplify()
            return self.prefactor * term

        constant = 0
        items = {}
        pos = -1
        for term in self[:]:
            pos += 1  # position of ``term`` in the (shrinking) list self
            if not hasattr(term, 'vartype'):
                if isinstance(term, dict):
                    # allow term of type {(0,): x}
                    assert list(term.keys()) == [(0,)]
                    term = term[(0,)]
                constant += term
                del self[pos]
                pos -= 1
                continue
            tag = tuple(term.sort())
            if tag in items:
                orig_prefac = items[tag].prefactor
                items[tag].prefactor += term.prefactor
                # treat a numerical cancellation (0.3333 - 0.3333) as zero
                if items[tag].prefactor and \
                        abs(items[tag].prefactor) / (abs(orig_prefac) + abs(term.prefactor)) < 1e-8:
                    items[tag].prefactor = 0
                del self[pos]
                pos -= 1
            else:
                # work on a copy so that the prefactor of shared terms is kept
                items[tag] = term.__class__(term, term.prefactor)
                self[pos] = items[tag]

        # get the optimized prefactor
        countprefact = defaultdict(int)
        nbplus, nbminus = 0, 0
        if constant not in [0, 1, -1]:
            countprefact[constant] += 1
            if constant.real + constant.imag > 0:
                nbplus += 1
            else:
                nbminus += 1

        for var in items.values():
            if var.prefactor == 0:
                self.remove(var)
            else:
                nb = var.prefactor
                if nb in [1, -1]:
                    continue
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1
        if countprefact and max(countprefact.values()) > 1:
            fact_prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
        else:
            fact_prefactor = 1
        if nbplus < nbminus:
            fact_prefactor *= -1
        self.prefactor *= fact_prefactor

        if fact_prefactor != 1:
            for i, a in enumerate(self):
                try:
                    a.prefactor /= fact_prefactor
                except AttributeError:
                    self[i] /= fact_prefactor

        if constant:
            self.append(constant / fact_prefactor)

        # deal with one/zero length object
        varlen = len(self)
        if varlen == 1:
            if hasattr(self[0], 'vartype'):
                return self.prefactor * self[0].simplify()
            #self[0] is a number
            return self.prefactor * self[0]
        elif varlen == 0:
            return 0
        return self

    def split(self, variables_id):
        """Return a dict whose keys are the powers of the variables
        ``variables_id`` and whose values are the objects remaining once those
        variables are removed (summed over the terms)."""
        out = defaultdict(int)
        for obj in self:
            for key, value in obj.split(variables_id).items():
                out[key] += self.prefactor * value
        return out

    def contains(self, variables):
        """Return True if one of the variables appears in one of the terms."""
        return any(v in obj for obj in self for v in variables)

    def get_all_var_names(self):
        """Return the variable names used by the terms (with repetitions)."""
        out = []
        for term in self:
            if hasattr(term, 'get_all_var_names'):
                out += term.get_all_var_names()
        return out

    def replace(self, id, expression):
        """Return a new sum where the variable ``id`` is replaced by ``expression``.

        ``expression`` cannot be zero.  self must be in canonical form (it must
        contain ONLY MultVariable), so this has to be called before factorize.
        The terms of self are modified by MultVariable.replace.
        """
        new = self.__class__()
        for obj in self:
            assert isinstance(obj, MultVariable)
            new += obj.replace(id, expression)
        new.prefactor = self.prefactor
        return new

    def expand(self, veto=()):
        """Pass from high level objects to low level objects.

        Terms providing an ``expand`` method are expanded (``veto`` is forwarded
        to them); other terms (numbers, plain MultVariable) are added as they
        are.  The prefactor is distributed on every term.
        """
        if not self:
            return self

        first = self[0]
        new = first.expand(veto) if hasattr(first, 'expand') else first
        if self.prefactor != 1:
            new = self.prefactor * new

        for item in self[1:]:
            if hasattr(item, 'expand'):
                term = item.expand(veto)
                if self.prefactor != 1:
                    term = self.prefactor * term
                new += term
            elif self.prefactor == 1:
                new = new + item
            else:
                new = new + self.prefactor * item
        return new

    def __mul__(self, obj):
        """Multiply by a number or by another AddVariable (distributed).

        Other objects return NotImplemented so that their own reflected
        multiplication is used.
        """
        if not hasattr(obj, 'vartype'): # obj is a number
            if not obj:
                return 0
            return self.__class__(self, self.prefactor * obj)
        elif obj.vartype == 1: # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i * j for i in self for j in obj]
            return new
        #force the program to look at obj * self
        return NotImplemented

    def __imul__(self, obj):
        """In-place version of __mul__ (in place only for a non-zero number)."""
        if not hasattr(obj, 'vartype'): # obj is a number
            if not obj:
                return 0
            self.prefactor *= obj
            return self
        elif obj.vartype == 1: # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i * j for i in self for j in obj]
            return new
        #force the program to look at obj * self
        return NotImplemented

    def __neg__(self):
        """Return the opposite of the sum; self is left unchanged."""
        return self.__class__(self, -self.prefactor)

    def __add__(self, obj):
        """Return a new sum self + obj (obj: number, MultVariable or AddVariable)."""
        if not hasattr(obj, 'vartype'):
            if not obj: # obj is zero
                return self
            new = self.__class__(self, self.prefactor)
            new.append(obj / self.prefactor)
            return new
        elif obj.vartype == 2: # obj is a MultVariable
            new = AddVariable(self, self.prefactor)
            if self.prefactor == 1:
                new.append(obj)
            else:
                new.append((1 / self.prefactor) * obj)
            return new
        elif obj.vartype == 1: # obj is a AddVariable
            new = AddVariable(self, self.prefactor)
            for item in obj:
                new.append(obj.prefactor / self.prefactor * item)
            return new
        #force to look at obj + self
        return NotImplemented

    def __iadd__(self, obj):
        """Add obj to self in place (see __add__ for the accepted types)."""
        if not hasattr(obj, 'vartype'):
            if not obj: # obj is zero
                return self
            self.append(obj / self.prefactor)
            return self
        elif obj.vartype == 2: # obj is a MultVariable
            if self.prefactor == 1:
                self.append(obj)
            else:
                self.append((1 / self.prefactor) * obj)
            return self
        elif obj.vartype == 1: # obj is a AddVariable
            for item in obj:
                self.append(obj.prefactor / self.prefactor * item)
            return self
        #force to look at obj + self
        return NotImplemented

    def __sub__(self, obj):
        return self + (-1) * obj

    def __rsub__(self, obj):
        return (-1) * self + obj

    __radd__ = __add__
    __rmul__ = __mul__

    def __div__(self, obj):
        """Division by a number."""
        return self.__mul__(1 / obj)

    __truediv__ = __div__

    def __rdiv__(self, obj):
        """``obj / self`` is not supported.

        The previous implementation called a non-existent ``__rmult__``
        method (and would have computed ``self / obj``).  Returning
        NotImplemented makes Python raise a TypeError.
        """
        return NotImplemented

    __rtruediv__ = __rdiv__

    def __str__(self):
        text = ''
        if self.prefactor != 1:
            text += str(self.prefactor) + ' * '
        text += '( '
        text += ' + '.join([str(item) for item in self])
        text += ' )'
        return text

    def count_term(self):
        """Find the variable present in the largest number of terms.

        Returns ``(nb, var_id)`` where ``nb`` is the number of terms containing
        ``var_id``.  Ties are broken by preferring the variable most
        correlated with the others, then the one with the smallest string
        representation.  Returns ``(1, None)`` when no variable appears twice
        and ``(0, None)`` when there is no variable at all.
        """
        count = defaultdict(int)
        # correlation[v1][v2]: number of terms containing both v1 and v2
        correlation = defaultdict(defaultdict(int))
        for term in self:
            try:
                set_term = set(term)
            except TypeError:
                #constant term
                continue
            for val1 in set_term:
                count[val1] += 1
                # allow to find optimized factorization for identical count
                for val2 in set_term:
                    correlation[val1][val2] += 1

        maxnb = max(count.values()) if count else 0
        possibility = [v for v, val in count.items() if val == maxnb]
        if maxnb == 1:
            return 1, None
        elif len(possibility) == 1:
            return maxnb, possibility[0]

        max_wgt, maxvar = 0, None
        for var in possibility:
            wgt = sum(w**2 for w in correlation[var].values()) / len(correlation[var])
            if wgt > max_wgt:
                maxvar = var
                max_wgt = wgt
                str_maxvar = str(KERNEL.objs[var])
            elif wgt == max_wgt:
                # keep the one with the lowest string expr
                new_str = str(KERNEL.objs[var])
                if new_str < str_maxvar:
                    maxvar = var
                    str_maxvar = new_str
        return maxnb, maxvar

    def factorize(self):
        """Try to factorize the expression as much as possible.

        The variable found by count_term is put in factor, giving
        ``prefactor * (VAR * NEWADD + CONSTANT)`` where NEWADD is factorized
        recursively.  The terms of self are modified in place (the factorized
        variable is removed from them).
        """
        nb_max, maxvar = self.count_term()
        if nb_max <= 1:
            #no factorization possible
            return self

        # split in MAXVAR * NEWADD + CONSTANT
        newadd = AddVariable()
        constant = AddVariable()
        for term in self:
            try:
                term.remove(maxvar)
            except (ValueError, AttributeError):
                # maxvar not in this term (or the term is a number)
                constant.append(term)
            else:
                if len(term):
                    newadd.append(term)
                else:
                    newadd.append(term.prefactor)
        newadd = newadd.factorize()

        # optimize the prefactor
        if isinstance(newadd, AddVariable):
            countprefact = defaultdict(int)
            nbplus, nbminus = 0, 0
            for nb in [a.prefactor for a in newadd if hasattr(a, 'prefactor')]:
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1

            # no term with a prefactor (only numbers): nothing to optimize
            if countprefact:
                newadd.prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
                if nbplus < nbminus:
                    newadd.prefactor *= -1
                if newadd.prefactor != 1:
                    for i, a in enumerate(newadd):
                        try:
                            a.prefactor /= newadd.prefactor
                        except AttributeError:
                            newadd[i] /= newadd.prefactor

        if len(constant) > 1:
            constant = constant.factorize()
        elif constant:
            constant = constant[0]
        else:
            out = MultContainer([KERNEL.objs[maxvar], newadd])
            out.prefactor = self.prefactor
            if newadd.prefactor != 1:
                out.prefactor *= newadd.prefactor
                newadd.prefactor = 1
            return out
        out = AddVariable([MultContainer([KERNEL.objs[maxvar], newadd]), constant],
                          self.prefactor)
        return out
574
class MultContainer(list):
    """Product of arbitrary symbolic objects: ``prefactor * t_1 * t_2 * ...``.

    AddVariable.factorize uses it to hold ``VAR * (sub-expression)``.
    """

    vartype = 6

    def __init__(self, *args):
        """Accept the same arguments as ``list``; the prefactor starts at 1."""
        self.prefactor = 1
        list.__init__(self, *args)

    def __str__(self):
        """ String representation """
        factors = ' * '.join([str(t) for t in self])
        if self.prefactor != 1:
            return '(%s * %s)' % (self.prefactor, factors)
        return '(%s)' % factors

    def factorize(self):
        """Factorize, in place, every factor that supports it and return self.

        Factors without a ``factorize`` method (e.g. the variable objects
        taken from KERNEL.objs) are kept unchanged.  The previous version
        returned None and failed on such factors.
        """
        self[:] = [term.factorize() if hasattr(term, 'factorize') else term
                   for term in self]
        return self
593
class MultVariable(array):
    """Product ``prefactor * v_1 * v_2 * ...`` of KERNEL variables.

    The factors are stored as their integer KERNEL ids in an ``array('i')``
    for speed; a repeated id represents a power.
    """
    vartype = 2
    addclass = AddVariable

    def __new__(cls, old=(), prefactor=1):
        # the ids are stored here: array.__init__ does not receive them
        return array.__new__(cls, 'i', old)

    def __init__(self, old=(), prefactor=1):
        """Set the numerical prefactor (the ids were stored by __new__)."""
        self.prefactor = prefactor
        assert isinstance(self.prefactor, (float, int, long, complex))

    def get_id(self):
        """Return the KERNEL id of a single-variable product."""
        assert len(self) == 1
        return self[0]

    def sort(self):
        """Sort the ids in place and return self."""
        a = list(self)
        a.sort()
        self[:] = array('i', a)
        return self

    def simplify(self):
        """Return the prefactor for an empty product, self otherwise."""
        if not len(self):
            return self.prefactor
        return self

    def split(self, variables_id):
        """Return a SplitCoefficient mapping the powers of ``variables_id`` in
        this product to the remaining product.

        self is modified in place: the variables of ``variables_id`` are
        removed from it.
        """
        key = tuple([self.count(i) for i in variables_id])
        arg = [var_id for var_id in self if var_id not in variables_id]
        self[:] = array('i', arg)
        return SplitCoefficient([(key, self)])

    def replace(self, id, expression):
        """Replace every occurrence of the variable ``id`` by ``expression``.

        ``expression`` (an AddVariable or a MultVariable) cannot be zero.
        self is modified in place.  A product not containing ``id`` is
        returned unchanged (previously, for a MultVariable expression, it
        was multiplied by ``expression`` anyway and ``id**n`` became
        ``expression**1``).
        """
        assert hasattr(expression, 'vartype') , 'expression should be of type Add or Mult'

        nb = self.count(id)
        if expression.vartype == 1: # AddVariable
            if not nb:
                return self
            for _ in range(nb):
                self.remove(id)
            new = self
            for _ in range(nb):
                new *= expression
            return new
        elif expression.vartype == 2: # MultVariable
            # remove every occurrence first: be carefull about A -> A * B
            for _ in range(nb):
                self.remove(id)
            for _ in range(nb):
                self.__imul__(expression)
            return self
        else:
            raise Exception('Cann\'t replace a Variable by %s' % type(expression))

    def get_all_var_names(self):
        """return the list of variable used in this multiplication"""
        return ['%s' % KERNEL.objs[n] for n in self]

    #Defining rule of Multiplication
    def __mul__(self, obj):
        """Return the product with a number, an AddVariable or a MultVariable.

        Multiplying by a zero number returns 0; a LorentzObjectRepresentation
        (vartype 4) is left to its own reflected multiplication.
        """
        if not hasattr(obj, 'vartype'): # should be a number
            if obj:
                return self.__class__(self, obj * self.prefactor)
            return 0
        elif obj.vartype == 1: # obj is an AddVariable
            new = obj.__class__([], self.prefactor * obj.prefactor)
            # distribute a prefactor-free copy of self on every term (self
            # itself is no longer modified temporarily)
            unit = self.__class__(self, 1)
            new[:] = [unit * term for term in obj]
            return new
        elif obj.vartype == 4:
            return NotImplemented

        return self.__class__(array.__add__(self, obj), self.prefactor * obj.prefactor)

    __rmul__ = __mul__

    def __imul__(self, obj):
        """In-place multiplication.

        With an AddVariable a new sum is returned and the prefactor of self
        (moved onto that sum) is reset to 1; with a zero number 0 is returned.
        """
        if not hasattr(obj, 'vartype'): # should be a number
            if obj:
                self.prefactor *= obj
                return self
            return 0
        elif obj.vartype == 1: # obj is an AddVariable
            new = obj.__class__([], self.prefactor * obj.prefactor)
            self.prefactor = 1
            new[:] = [self * term for term in obj]
            return new
        elif obj.vartype == 4:
            return NotImplemented

        self.prefactor *= obj.prefactor
        return array.__iadd__(self, obj)

    def __pow__(self, value):
        """Return self**value for a non-negative integer (1 for value 0)."""
        out = 1
        for _ in range(value):
            out *= self
        return out

    def __add__(self, obj):
        """ define the adition with different object"""
        if not obj:
            return self
        elif not hasattr(obj, 'vartype') or obj.vartype == 2:
            return self.addclass([self, obj])
        #call the implementation of addition implemented in obj
        return NotImplemented

    __radd__ = __add__
    __iadd__ = __add__

    def __sub__(self, obj):
        return self + (-1) * obj

    def __neg__(self):
        """Return the opposite product; self is left unchanged."""
        return self.__class__(self, -self.prefactor)

    def __rsub__(self, obj):
        return (-1) * self + obj

    def __idiv__(self, obj):
        """In-place division by a number (ONLY NUMBER DIVISION ALLOWED)."""
        assert not hasattr(obj, 'vartype')
        self.prefactor /= obj
        return self

    # ``x /= n`` keeps its in-place behaviour under true division
    __itruediv__ = __idiv__

    def __div__(self, obj):
        """Return a copy of self divided by the number ``obj``.

        This used to be an alias of __idiv__, so ``a / 2`` modified ``a``.
        """
        assert not hasattr(obj, 'vartype')
        return self.__class__(self, self.prefactor / obj)

    __truediv__ = __div__

    def __str__(self):
        """ String representation """
        t = ['%s' % KERNEL.objs[n] for n in self]
        if self.prefactor != 1:
            return '(%s * %s)' % (self.prefactor, ' * '.join(t))
        return '(%s)' % (' * '.join(t))

    # ``__rep__`` (historical typo) is kept for backward compatibility; the
    # real ``__repr__`` makes containers display (and str()) the prefactor.
    __rep__ = __str__
    __repr__ = __str__

    def factorize(self):
        """A single product cannot be factorized further: return self."""
        return self
778
#===============================================================================
# FactoryVar
#===============================================================================
class C_Variable(str):
    """Name of a complex-valued scalar variable (the object kept in the KERNEL)."""
    type = 'complex'
    vartype = 0
786
class R_Variable(str):
    """Name of a real-valued ('double') scalar variable."""
    type = 'double'
    vartype = 0
790
class ExtVariable(str):
    """Name of a variable of type 'parameter' (no user visible in this chunk)."""
    type = 'parameter'
    vartype = 0
794
class FactoryVar(object):
    """Factory turning a variable name into a single-variable product.

    ``FactoryVar(name, baseclass, *args)`` does not build a FactoryVar: it
    returns ``cls.mult_class([id])`` where ``id`` is the KERNEL id of
    ``name`` (since the result is not a FactoryVar, __init__ is not run).
    The first time ``name`` is requested, ``baseclass(name, *args)`` is
    created, registered in the KERNEL and given an ``id`` attribute.
    """
    mult_class = MultVariable # The class for the multiplication

    def __new__(cls, name, baseclass, *args):
        try:
            var_id = KERNEL[name]
        except KeyError:
            # unknown name: create and register the underlying object
            new_obj = baseclass(name, *args)
            var_id = KERNEL.add(name, new_obj)
            new_obj.id = var_id
        return cls.mult_class([var_id])
811
class Variable(FactoryVar):
    """Factory for a scalar variable, complex by default.

    ``Variable(name, type)`` returns the MultVariable produced by
    ``FactoryVar(name, type)``; ``type`` is the class used to build the KERNEL
    object when ``name`` is registered for the first time.
    """

    def __new__(cls, name, type=C_Variable):
        product = FactoryVar(name, type)
        return product
816
class DVariable(FactoryVar):
    """Factory for a variable that is real unless the aloha mode requires complex.

    It is complex when aloha.complex_mass is set and the name starts with
    'M', 'W' or 'OM', or when aloha.loop_mode is set and the name starts
    with 'P'; otherwise it is an R_Variable.
    """

    def __new__(cls, name):
        #some parameter are pass to complex
        is_complex = aloha.complex_mass and \
            (name[0] in ['M', 'W'] or name.startswith('OM'))
        is_complex = is_complex or (aloha.loop_mode and name.startswith('P'))
        return FactoryVar(name, C_Variable if is_complex else R_Variable)
829
#===============================================================================
# Object for Analytical Representation of Lorentz object (not scalar one)
#===============================================================================


#===============================================================================
# MultLorentz
#===============================================================================
class MultLorentz(MultVariable):
    """Product of Lorentz objects (KERNEL ids, see MultVariable)."""

    add_class = AddVariable # Define which class describe the addition

    def _find_contraction(self, attr):
        """Shared implementation of the find_*contraction methods.

        ``attr`` names the index list compared ('lorentz_ind' or 'spin_ind')
        on the KERNEL objects referenced by self.  The elements of self are
        integer ids, so the objects have to be fetched from KERNEL.objs (the
        previous code read the attribute on the ids and always failed).
        """
        out = {}
        factors = [KERNEL.objs[var] for var in self]
        for i, fact in enumerate(factors):
            # compare each index of this element with the following elements
            for j, index in enumerate(getattr(fact, attr)):
                for k in range(i + 1, len(factors)):
                    try:
                        pos = getattr(factors[k], attr).index(index)
                    except ValueError:
                        continue
                    out[(i, j)] = (k, pos)
                    out[(k, pos)] = (i, j)
        return out

    def find_lorentzcontraction(self):
        """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
        the Lorentz contraction in this Multiplication."""
        return self._find_contraction('lorentz_ind')

    def find_spincontraction(self):
        """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
        the spin contraction in this Multiplication."""
        return self._find_contraction('spin_ind')

    def neighboor(self, home):
        """Return the first not yet expanded object (from ``self.unused``)
        sharing a Lorentz or spin index with ``home``, or None."""
        for var in self.unused:
            obj = KERNEL.objs[var]
            if obj.has_component(home.lorentz_ind, home.spin_ind):
                return obj
        return None

    def expand(self, veto=()):
        """Expand each part of the product and combine them.

        Chains of contracted objects are built to minimize the number of
        uncontracted indices, starting preferably from objects with
        ``contract_first`` set.  Scalar chains whose value is an AddVariable
        are simplified and replaced by a KERNEL TMP variable, unless they
        contain one of the ``veto`` variable ids.
        """
        self.unused = self[:] # list of not expanded
        # made in a list the interesting starting point for the computation
        basic_end_point = [var for var in self if KERNEL.objs[var].contract_first]
        product_term = [] #store result of intermediate chains
        current = None # current point in the working chain (ids may be 0)

        while self.unused:
            #Loop untill we have expand everything
            if current is None:
                # First we need to have a starting point
                try:
                    # look in priority in basic_end_point (P/S/fermion/...)
                    current = basic_end_point.pop()
                except IndexError:
                    #take one of the remaining
                    current = self.unused.pop()
                else:
                    #check that this one is not already use
                    if current not in self.unused:
                        current = None
                        continue
                    #remove of the unuse (usualy done in the pop)
                    self.unused.remove(current)
                cur_obj = KERNEL.objs[current]
                # initialize the new chain
                product_term.append(cur_obj.expand())

            # We have a point -> find one term contracted with the current
            # chain and not yet expanded.
            var_obj = self.neighboor(product_term[-1])
            if var_obj is not None:
                product_term[-1] *= var_obj.expand()
                cur_obj = var_obj
                self.unused.remove(cur_obj.id)
                continue

            current = None

        # Multiply all those current
        # For Fermion/Vector only one can carry index.
        out = self.prefactor
        for fact in product_term:
            if hasattr(fact, 'vartype') and fact.lorentz_ind == fact.spin_ind == []:
                scalar = fact.get_rep([0])
                if hasattr(scalar, 'vartype') and scalar.vartype == 1:
                    if not veto or not scalar.contains(veto):
                        scalar = scalar.simplify()
                        prefactor = 1

                        if hasattr(scalar, 'vartype') and scalar.prefactor not in [1, -1]:
                            prefactor = scalar.prefactor
                            scalar.prefactor = 1
                        new = KERNEL.add_expression_contraction(scalar)
                        fact.set_rep([0], prefactor * new)
            out *= fact
        return out

    def __copy__(self):
        """ create a shadow copy """
        new = MultLorentz(self)
        new.prefactor = self.prefactor
        return new
974
#===============================================================================
# LorentzObject
#===============================================================================
class LorentzObject(object):
    """Base class of the symbolic Lorentz (HELAS-like) objects.

    Subclasses are expected to implement create_representation(), which
    must set the ``representation`` attribute.
    """

    contract_first = 0 # preferred starting point in MultLorentz.expand
    mult_class = MultLorentz # The class for the multiplication
    add_class = AddVariable # The class for the addition

    class VariableError(Exception):
        """Raised when an object cannot provide its representation.

        create_representation used ``self.VariableError`` although no such
        attribute existed, which turned the intended error into an
        AttributeError.  Subclasses may still define their own.
        """

    def __init__(self, name, lor_ind, spin_ind, tags=()):
        """Store the name and the index lists; register ``tags`` in the KERNEL."""
        assert isinstance(lor_ind, list)
        assert isinstance(spin_ind, list)

        self.name = name
        self.lorentz_ind = lor_ind
        self.spin_ind = spin_ind
        KERNEL.add_tag(set(tags))

    def expand(self):
        """Return the cached ``representation``, building it on first use."""
        try:
            return self.representation
        except AttributeError:
            self.create_representation()
            return self.representation

    def create_representation(self):
        raise self.VariableError("This Object %s doesn't have define representation" % self.__class__.__name__)

    def has_component(self, lor_list, spin_list):
        """Return True if this object carries one of the given indices."""
        return any(ind in self.lorentz_ind for ind in lor_list) or \
            any(ind in self.spin_ind for ind in spin_list)

    def __str__(self):
        return '%s' % self.name
1019
class FactoryLorentz(FactoryVar):
    """Factory for Lorentz objects.

    ``FactoryLorentz(*args)`` builds a unique KERNEL name from ``args`` (see
    get_unique_name) and returns the corresponding single-element
    MultLorentz, creating ``object_class(name, *args)`` the first time.
    """

    mult_class = MultLorentz # The class for the multiplication
    object_class = LorentzObject # Define How to create the basic object.

    def __new__(cls, *args):
        name = cls.get_unique_name(*args)
        return FactoryVar.__new__(cls, name, cls.object_class, *args)

    @classmethod
    def get_unique_name(cls, *args):
        """Default unique name: the class name followed by the arguments.

        Arguments are formatted with '%s', so non-string arguments (integers,
        or the index lists taken by the default LorentzObject) are accepted;
        string arguments give the same name as before.
        """
        return '_L_%(class)s_%(args)s' % {
            'class': cls.__name__,
            'args': '_'.join(['%s' % (arg,) for arg in args]),
        }
1038
#===============================================================================
# LorentzObjectRepresentation
#===============================================================================
class LorentzObjectRepresentation(dict):
    """A concrete representation of the LorentzObject.

    The object is a dict whose keys are tuples of index values (first the
    Lorentz indices, then the spin indices, in the order given by
    self.lorentz_ind and self.spin_ind) and whose values are the
    corresponding components. An object without any index (a scalar) stores
    its single value under the key (0,).
    """

    vartype = 4 # Optimization for instance recognition

    class LorentzObjectRepresentationError(Exception):
        """Specify error for LorentzObjectRepresentation"""

    def __init__(self, representation, lorentz_indices, spin_indices):
        """Initialize the lorentz object representation.

        representation -- for an object with at least one index: anything
            accepted by dict() (index tuple -> value). For a scalar: either
            a dict that is empty or exactly {(0,): value}, or directly the
            scalar value.
        lorentz_indices -- the Lorentz index labels
        spin_indices -- the spin index labels

        Raises LorentzObjectRepresentationError when a scalar is built from
        a non-empty dict which is not exactly {(0,): value}.
        """

        self.lorentz_ind = lorentz_indices #lorentz indices
        self.nb_lor = len(lorentz_indices) #their number
        self.spin_ind = spin_indices #spin indices
        self.nb_spin = len(spin_indices) #their number
        self.nb_ind = self.nb_lor + self.nb_spin #total number of indices

        #store the representation
        if self.lorentz_ind or self.spin_ind:
            dict.__init__(self, representation)
        elif isinstance(representation, dict):
            # scalar given as a dictionary: only the (0,) entry is meaningful
            if len(representation) == 0:
                self[(0,)] = 0
            elif len(representation) == 1 and (0,) in representation:
                self[(0,)] = representation[(0,)]
            else:
                raise self.LorentzObjectRepresentationError("There is no key of (0,) in representation.")
        else:
            # scalar given directly by its value (dicts are handled above)
            self[(0,)] = representation

    def __str__(self):
        """String representation: the index labels followed by one
        'index values --> component' line per component."""
        text = 'lorentz index :' + str(self.lorentz_ind) + '\n'
        text += 'spin index :' + str(self.spin_ind) + '\n'
        for ind in self.listindices():
            ind = tuple(ind)
            text += str(ind) + ' --> '
            text += str(self.get_rep(ind)) + '\n'
        return text

    def get_rep(self, indices):
        """return the value/Variable associate to the indices"""
        return self[tuple(indices)]

    def set_rep(self, indices, value):
        """assign 'value' at the indices position"""

        self[tuple(indices)] = value

    def listindices(self):
        """Return an iterator in order to be able to loop easily on all the
        indices of the object."""
        return IndicesIterator(self.nb_ind)

    @staticmethod
    def get_mapping(l1, l2, switch_order=None):
        """Append to switch_order, for each element of l1, its position in l2.

        Positions are shifted by the length of switch_order on entry, so that
        calling this first for the Lorentz indices and then for the spin
        indices builds a single mapping over all the indices. The list is
        modified in place and also returned; when it is not given, a new
        list is used.

        Raises LorentzObjectRepresentationError if an element of l1 is not
        in l2.
        """
        if switch_order is None:
            # fresh list at each call (a mutable default would be shared and
            # keep growing between calls)
            switch_order = []
        shift = len(switch_order)
        for value in l1:
            try:
                index = l2.index(value)
            except ValueError:
                raise LorentzObjectRepresentation.LorentzObjectRepresentationError(
                    "Invalid addition. Object doen't have the same lorentz "+ \
                    "indices : %s != %s" % (l1, l2))
            else:
                switch_order.append(shift + index)
        return switch_order

    def __add__(self, obj, fact=1):
        """Return self + fact * obj as a new object.

        obj is either a plain value (only allowed when self is a scalar) or
        another LorentzObjectRepresentation with the same indices, possibly
        in a different order. In the latter case the result carries the
        indices in the order of obj. A falsy obj returns self unchanged.
        """

        if not obj:
            return self

        if not hasattr(obj, 'vartype'):
            assert self.lorentz_ind == []
            assert self.spin_ind == []
            new = self[(0,)] + obj * fact
            out = LorentzObjectRepresentation(new, [], [])
            return out

        assert(obj.vartype == 4 == self.vartype) # are LorentzObjectRepresentation

        if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:
            # if the order of indices are different compute a mapping:
            # switch turns index values written in the order of obj into the
            # same values written in the order of self.
            switch_order = []
            self.get_mapping(self.lorentz_ind, obj.lorentz_ind, switch_order)
            self.get_mapping(self.spin_ind, obj.spin_ind, switch_order)
            switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))])
        else:
            # no mapping needed (define switch as identity)
            switch = lambda ind : (ind)

        # Some sanity check
        assert tuple(self.lorentz_ind+self.spin_ind) == tuple(switch(obj.lorentz_ind+obj.spin_ind)), '%s!=%s' % (self.lorentz_ind+self.spin_ind, switch(obj.lorentz_ind+obj.spin_ind))
        assert tuple(self.lorentz_ind) == tuple(switch(obj.lorentz_ind)), '%s!=%s' % (tuple(self.lorentz_ind), switch(obj.lorentz_ind))

        # define an empty representation, in the index order of obj: every
        # `ind` below is therefore in the order of obj and has to be switched
        # before reading self.
        new = LorentzObjectRepresentation({}, obj.lorentz_ind, obj.spin_ind)

        # loop over all indices and fullfill the new object
        if fact == 1:
            for ind in self.listindices():
                value = obj.get_rep(ind) + self.get_rep(switch(ind))
                new.set_rep(ind, value)
        else:
            for ind in self.listindices():
                value = fact * obj.get_rep(ind) + self.get_rep(switch(ind))
                new.set_rep(ind, value)

        return new

    def __iadd__(self, obj, fact=1):
        """In-place self += fact * obj.

        obj is either a plain value (only allowed when self is a scalar) or
        another LorentzObjectRepresentation with the same indices, possibly
        in a different order; self keeps its own index order.
        """

        if not obj:
            return self

        if not hasattr(obj, 'vartype'):
            # plain value: same rule as in __add__ (scalar only)
            assert self.lorentz_ind == []
            assert self.spin_ind == []
            self[(0,)] = self[(0,)] + obj * fact
            return self

        assert(obj.vartype == 4 == self.vartype) # are LorentzObjectRepresentation

        if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:

            # if the order of indices are different compute a mapping:
            # switch turns index values written in the order of self into the
            # same values written in the order of obj.
            switch_order = []
            self.get_mapping(obj.lorentz_ind, self.lorentz_ind, switch_order)
            self.get_mapping(obj.spin_ind, self.spin_ind, switch_order)
            switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))])
        else:
            # no mapping needed (define switch as identity)
            switch = lambda ind : (ind)

        # Some sanity check
        assert tuple(switch(self.lorentz_ind+self.spin_ind)) == tuple(obj.lorentz_ind+obj.spin_ind), '%s!=%s' % (switch(self.lorentz_ind+self.spin_ind), (obj.lorentz_ind+obj.spin_ind))
        assert tuple(switch(self.lorentz_ind) )== tuple(obj.lorentz_ind), '%s!=%s' % (switch(self.lorentz_ind), tuple(obj.lorentz_ind))

        # loop over all indices and fullfill the object
        if fact == 1:
            for ind in self.listindices():
                self[tuple(ind)] += obj.get_rep(switch(ind))
        else:
            for ind in self.listindices():
                self[tuple(ind)] += fact * obj.get_rep(switch(ind))
        return self

    def __sub__(self, obj):
        """Return self - obj."""
        return self.__add__(obj, fact= -1)

    def __rsub__(self, obj):
        """Return obj - self (called when obj cannot subtract self itself)."""
        if hasattr(obj, 'vartype'):
            return obj.__add__(self, fact= -1)
        # plain values (e.g. numbers) have no `fact` keyword in __add__:
        # compute (-1) * self + obj instead.
        return self.__mul__(-1).__add__(obj)

    def __isub__(self, obj):
        """self -= obj (returns a new object, self is not modified)."""
        return self.__add__(obj, fact= -1)

    def __neg__(self):
        """Return -self as a new object (self is not modified)."""
        return self.__mul__(-1)

    def __mul__(self, obj):
        """multiplication performing directly the einstein/spin sommation.

        A plain value multiplies every component. For another
        LorentzObjectRepresentation, the Lorentz/spin labels present in both
        objects are summed over; without any common label the result is the
        tensor product.
        """

        if not hasattr(obj, 'vartype'):
            out = LorentzObjectRepresentation({}, self.lorentz_ind, self.spin_ind)
            for ind in out.listindices():
                out.set_rep(ind, obj * self.get_rep(ind))
            return out

        # Sanity Check
        assert(obj.__class__ == LorentzObjectRepresentation), \
            '%s is not valid class for this operation' %type(obj)

        # compute information on the status of the index (which are contracted/
        #not contracted
        l_ind, sum_l_ind = self.compare_indices(self.lorentz_ind, \
                                                obj.lorentz_ind)
        s_ind, sum_s_ind = self.compare_indices(self.spin_ind, \
                                                obj.spin_ind)
        if not(sum_l_ind or sum_s_ind):
            # No contraction made a tensor product
            return self.tensor_product(obj)

        # elsewher made a spin contraction
        # create an empty representation but with correct indices
        new_object = LorentzObjectRepresentation({}, l_ind, s_ind)
        #loop and fullfill the representation
        for indices in new_object.listindices():
            #made a dictionary (pos -> index_value) for how call the object
            dict_l_ind = self.pass_ind_in_dict(indices[:len(l_ind)], l_ind)
            dict_s_ind = self.pass_ind_in_dict(indices[len(l_ind):], s_ind)
            #add the new value
            new_object.set_rep(indices, \
                               self.contraction(obj, sum_l_ind, sum_s_ind, \
                                                dict_l_ind, dict_s_ind))

        return new_object

    __rmul__ = __mul__
    __imul__ = __mul__

    def contraction(self, obj, l_sum, s_sum, l_dict, s_dict):
        """ make the Lorentz/spin contraction of object self and obj.
        l_sum/s_sum are the position of the sum indices
        l_dict/s_dict are dict given the value of the fix indices (indices->value)

        l_dict and s_dict are updated in place with the values of the summed
        indices. Each term gets a factor (-1) for every summed Lorentz index
        whose value is not 0 (presumably a diag(+1,-1,-1,-1) metric).
        """
        out = 0 # initial value for the output
        len_l = len(l_sum) #store len for optimization
        len_s = len(s_sum) # same

        # loop over the possibility for the sum indices and update the dictionary
        # (indices->value)
        for l_value in IndicesIterator(len_l):
            l_dict.update(self.pass_ind_in_dict(l_value, l_sum))
            for s_value in IndicesIterator(len_s):
                s_dict.update(self.pass_ind_in_dict(s_value, s_sum))

                #return the indices in the correct order
                self_ind = self.combine_indices(l_dict, s_dict)
                obj_ind = obj.combine_indices(l_dict, s_dict)

                # call the object
                factor = obj.get_rep(obj_ind) * self.get_rep(self_ind)

                if factor:
                    #compute the prefactor due to the lorentz contraction
                    try:
                        factor.prefactor *= (-1) ** (len(l_value) - l_value.count(0))
                    except Exception:
                        # objects without a prefactor (e.g. plain numbers)
                        factor *= (-1) ** (len(l_value) - l_value.count(0))
                    out += factor
        return out

    def tensor_product(self, obj):
        """ return the tensorial product of the object

        The result carries the Lorentz indices of self then of obj, and the
        spin indices of self then of obj.
        """
        assert(obj.vartype == 4) #isinstance(obj, LorentzObjectRepresentation))

        new_object = LorentzObjectRepresentation({}, \
                                                 self.lorentz_ind + obj.lorentz_ind, \
                                                 self.spin_ind + obj.spin_ind)

        #some shortcut
        lor1 = self.nb_lor
        lor2 = obj.nb_lor
        spin1 = self.nb_spin
        spin2 = obj.nb_spin

        #define how to call build the indices first for the first object
        if lor1 == 0 == spin1:
            #special case for scalar
            selfind = lambda indices: [0]
        else:
            selfind = lambda indices: indices[:lor1] + \
                                      indices[lor1 + lor2: lor1 + lor2 + spin1]

        #then for the second
        if lor2 == 0 == spin2:
            #special case for scalar
            objind = lambda indices: [0]
        else:
            objind = lambda indices: indices[lor1: lor1 + lor2] + \
                                     indices[lor1 + lor2 + spin1:]

        # loop on the indices and assign the product
        for indices in new_object.listindices():

            fac1 = self.get_rep(tuple(selfind(indices)))
            fac2 = obj.get_rep(tuple(objind(indices)))
            new_object.set_rep(indices, fac1 * fac2)

        return new_object

    def factorize(self):
        """Try to factorize each component (in place) and return self.

        Falsy components and plain values (without vartype, e.g. numbers)
        are left untouched.
        """
        for ind, fact in self.items():
            if fact and hasattr(fact, 'vartype'):
                self.set_rep(ind, fact.factorize())

        return self

    def simplify(self):
        """Check if we can simplify the object (check for non treated Sum)"""

        #Look for internal simplification
        for ind, term in self.items():
            if hasattr(term, 'vartype'):
                self[ind] = term.simplify()
        #no additional simplification
        return self

    @staticmethod
    def compare_indices(list1, list2):
        """return two list, the first one contains the position of non summed
        index and the second one the position of summed index.

        Order is preserved: unique labels of list1 first, then those of
        list2; summed labels in the order of list1.
        """
        are_unique, are_sum = [], []
        # loop over the first list and check if they are in the second list
        for indice in list1:
            if indice in list2:
                are_sum.append(indice)
            else:
                are_unique.append(indice)

        # loop over the second list for additional unique item
        for indice in list2:
            if indice not in are_sum:
                are_unique.append(indice)

        # return value
        return are_unique, are_sum

    @staticmethod
    def pass_ind_in_dict(indices, key):
        """made a dictionary (pos -> index_value) for how call the object"""
        if not key:
            return {}
        out = {}
        for i, ind in enumerate(indices):
            out[key[i]] = ind
        return out

    def combine_indices(self, l_dict, s_dict):
        """return the indices in the correct order following the dicts rules"""

        out = []
        # First for the Lorentz indices
        for value in self.lorentz_ind:
            out.append(l_dict[value])
        # Same for the spin
        for value in self.spin_ind:
            out.append(s_dict[value])

        return out

    def split(self, variables_id):
        """return a dict with the key being the power associated to each variables
        and the value being the object remaining after the suppression of all
        the variable"""

        out = SplitCoefficient()
        # key used for purely numerical components (power 0 for every variable)
        zero_key = tuple([0] * len(variables_id))
        zero_rep = {}
        for ind in self.listindices():
            zero_rep[tuple(ind)] = 0

        for ind in self.listindices():
            component = self.get_rep(ind)
            # There is no function split if the element is just a simple number
            if isinstance(component, numbers.Number):
                if zero_key not in out:
                    out[zero_key] = LorentzObjectRepresentation(dict(zero_rep),
                                                                self.lorentz_ind, self.spin_ind)
                out[zero_key][tuple(ind)] += component
                continue

            for key, value in component.split(variables_id).items():
                if key not in out:
                    out[key] = LorentzObjectRepresentation(dict(zero_rep),
                                                           self.lorentz_ind, self.spin_ind)
                out[key][tuple(ind)] += value

        return out
#===============================================================================
# IndicesIterator
#===============================================================================
class IndicesIterator:
    """Iterator over all the values of a list of indices.

    Each index takes the values 0, 1, 2 and 3, the first index varying
    fastest, so 4**len lists are produced. For len == 0 (scalar object) a
    single value [0] is produced.
    Note: for len > 0 the same list object is returned at every step and
    updated in place; copy it (e.g. tuple(ind)) to keep a value.
    Works with both the Python 2 (next) and Python 3 (__next__) protocols.
    """

    def __init__(self, len):
        """ create an iterator looping over the indices of a list of len "len"
        with each value can take value between 0 and 3 """

        self.len = len # number of indices
        if len:
            # initialize the position. The first position is -1 due to the method
            #in place which start by rising an index before returning smtg
            self.data = [-1] + [0] * (len - 1)
        else:
            # Special case for Scalar object: the instance-level `next`
            # overrides the class method (see __next__)
            self.data = 0
            self.next = self.nextscalar

    def __iter__(self):
        return self

    def next(self):
        """Increase the indices like an odometer and return them; raise
        StopIteration once every combination has been returned."""
        for i in range(self.len):
            if self.data[i] < 3:
                self.data[i] += 1
                return self.data
            else:
                self.data[i] = 0
        raise StopIteration

    def __next__(self):
        """Python 3 iterator protocol.

        Python 3 looks __next__ up on the class, so delegate to self.next,
        which resolves to the instance-level nextscalar for scalar objects.
        """
        return self.next()

    def nextscalar(self):
        """Return [0] the first time, then raise StopIteration."""
        if self.data:
            raise StopIteration
        else:
            self.data = True
            return [0]
class SplitCoefficient(dict):
    """Dictionary of coefficients keyed by tuples of powers.

    It behaves exactly like a plain dict; in addition every instance starts
    with an empty ``tag`` set.
    """

    def __init__(self, *args, **opt):
        dict.__init__(self, *args, **opt)
        # extra annotations attached to the coefficients, empty on creation
        self.tag = set()

    def get_max_rank(self):
        """Return the largest value found among the first four entries of
        every key.

        Raises ValueError when the dictionary is empty (or when a key has no
        entry at all).
        """
        per_key_maximum = (max(powers[:4]) for powers in self)
        return max(per_key_maximum)
if __name__ == '__main__':
    # Small profiling harness for LorentzObjectRepresentation.compare_indices
    import cProfile

    def create():
        """Call compare_indices 10000 times on short index lists."""
        for counter in range(10000):
            LorentzObjectRepresentation.compare_indices(range(counter % 10), [4, 3, 5])

    cProfile.run('create()')