1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45 from __future__ import division
46 from array import array
47 import collections
48 from fractions import Fraction
49 import numbers
50 import re
51 import aloha
52
53 try:
54 import madgraph.various.misc as misc
55 except Exception:
56 import aloha.misc as misc
62
64 """ a class to encapsulate all computation. Limit side effect """
65
67 self.objs = []
68 self.use_tag = set()
69 self.id = -1
70 self.reduced_expr = {}
71 self.fct_expr = {}
72 self.reduced_expr2 = {}
73 self.inverted_fct = {}
74 self.has_pi = False
75 self.unknow_fct = []
76 dict.__init__(self)
77
81
82 - def add(self, name, obj):
83 self.id += 1
84 self.objs.append(obj)
85 self[name] = self.id
86 return self.id
87
88 - def get(self, name):
89 return self.objs[self[name]]
90
93
95 """return the list of identification number associate to the
96 given variables names. If a variable didn't exists, create it (in complex).
97 """
98 out = []
99 for var in variables:
100 try:
101 id = self[var]
102 except KeyError:
103 assert var not in ['M','W']
104 id = Variable(var).get_id()
105 out.append(id)
106 return out
107
108
110
111 str_expr = str(expression)
112 if str_expr in self.reduced_expr:
113 out, tag = self.reduced_expr[str_expr]
114 self.add_tag((tag,))
115 return out
116 if expression == 0:
117 return 0
118 new_2 = expression.simplify()
119 if new_2 == 0:
120 return 0
121
122 tag = 'TMP%s' % len(self.reduced_expr)
123 new = Variable(tag)
124 self.reduced_expr[str_expr] = [new, tag]
125 new_2 = new_2.factorize()
126 self.reduced_expr2[tag] = new_2
127 self.add_tag((tag,))
128
129
130 return new
131
132 known_fct = ['/', 'log', 'pow', 'sin', 'cos', 'asin', 'acos', 'tan', 'cot', 'acot',
133 'theta_function', 'exp']
135
136 if not (fct_tag.startswith('cmath.') or fct_tag in self.known_fct or
137 (fct_tag, len(args)) in self.unknow_fct):
138 self.unknow_fct.append( (fct_tag, len(args)) )
139
140 argument = []
141 for expression in args:
142 if isinstance(expression, (MultLorentz, AddVariable, LorentzObject)):
143 try:
144 expr = expression.expand().get_rep([0])
145 except KeyError, error:
146 if error.args != ((0,),):
147 raise
148 else:
149 raise aloha.ALOHAERROR, '''Error in input format.
150 Argument of function (or denominator) should be scalar.
151 We found %s''' % expression
152 new = expr.simplify()
153 if not isinstance(new, numbers.Number):
154 new = new.factorize()
155 argument.append(new)
156 else:
157 argument.append(expression)
158 for arg in argument:
159 val = re.findall(r'''\bFCT(\d*)\b''', str(arg))
160 for v in val:
161 self.add_tag(('FCT%s' % v,))
162
163
164 if (fct_tag.startswith('cmath.') or fct_tag in self.known_fct) and \
165 all(isinstance(x, numbers.Number) for x in argument):
166 import cmath
167 if fct_tag.startswith('cmath.'):
168 module = ''
169 else:
170 module = 'cmath.'
171 try:
172 return str(eval("%s%s(%s)" % (module,fct_tag, ','.join(`x` for x in argument))))
173 except Exception, error:
174 print error
175 print "cmath.%s(%s)" % (fct_tag, ','.join(`x` for x in argument))
176 if str(fct_tag)+str(argument) in self.inverted_fct:
177 tag = self.inverted_fct[str(fct_tag)+str(argument)]
178 v = tag.split('(')[1][:-1]
179 self.add_tag(('FCT%s' % v,))
180 return tag
181 else:
182 id = len(self.fct_expr)
183 tag = 'FCT%s' % id
184 self.inverted_fct[str(fct_tag)+str(argument)] = 'FCT(%s)' % id
185 self.fct_expr[tag] = (fct_tag, argument)
186 self.reduced_expr2[tag] = (fct_tag, argument)
187 self.add_tag((tag,))
188 return 'FCT(%s)' % id
189
190 KERNEL = Computation()
196 """ A list of Variable/ConstantObject/... This object represent the operation
197 between those object."""
198
199
200 vartype = 1
201
202 - def __init__(self, old_data=[], prefactor=1):
203 """ initialization of the object with default value """
204
205 self.prefactor = prefactor
206
207 list.__init__(self, old_data)
208
210 """ apply rule of simplification """
211
212
213 if len(self) == 1:
214 return self.prefactor * self[0].simplify()
215 constant = 0
216 items = {}
217 pos = -1
218 for term in self[:]:
219 pos += 1
220 if not hasattr(term, 'vartype'):
221 if isinstance(term, dict):
222
223 assert term.values() == [0]
224 term = term[(0,)]
225 constant += term
226 del self[pos]
227 pos -= 1
228 continue
229 tag = tuple(term.sort())
230 if tag in items:
231 orig_prefac = items[tag].prefactor
232 items[tag].prefactor += term.prefactor
233 if items[tag].prefactor and \
234 abs(items[tag].prefactor) / (abs(orig_prefac)+abs(term.prefactor)) < 1e-8:
235 items[tag].prefactor = 0
236 del self[pos]
237 pos -=1
238 else:
239 items[tag] = term.__class__(term, term.prefactor)
240 self[pos] = items[tag]
241
242
243 countprefact = defaultdict(int)
244 nbplus, nbminus = 0,0
245 if constant not in [0, 1,-1]:
246 countprefact[constant] += 1
247 if constant.real + constant.imag > 0:
248 nbplus += 1
249 else:
250 nbminus += 1
251
252 for var in items.values():
253 if var.prefactor == 0:
254 self.remove(var)
255 else:
256 nb = var.prefactor
257 if nb in [1,-1]:
258 continue
259 countprefact[abs(nb)] +=1
260 if nb.real + nb.imag > 0:
261 nbplus += 1
262 else:
263 nbminus += 1
264 if countprefact and max(countprefact.values()) >1:
265 fact_prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
266 else:
267 fact_prefactor = 1
268 if nbplus < nbminus:
269 fact_prefactor *= -1
270 self.prefactor *= fact_prefactor
271
272 if fact_prefactor != 1:
273 for i,a in enumerate(self):
274 try:
275 a.prefactor /= fact_prefactor
276 except AttributeError:
277 self[i] /= fact_prefactor
278
279 if constant:
280 self.append(constant/ fact_prefactor )
281
282
283 varlen = len(self)
284 if varlen == 1:
285 if hasattr(self[0], 'vartype'):
286 return self.prefactor * self[0].simplify()
287 else:
288
289 return self.prefactor * self[0]
290 elif varlen == 0:
291 return 0
292 return self
293
294 - def split(self, variables_id):
295 """return a dict with the key being the power associated to each variables
296 and the value being the object remaining after the suppression of all
297 the variable"""
298
299 out = defaultdict(int)
300 for obj in self:
301 for key, value in obj.split(variables_id).items():
302 out[key] += self.prefactor * value
303 return out
304
306 """returns true if one of the variables is in the expression"""
307
308 return any((v in obj for obj in self for v in variables ))
309
310
312
313 out = []
314 for term in self:
315 if hasattr(term, 'get_all_var_names'):
316 out += term.get_all_var_names()
317 return out
318
319
320
321 - def replace(self, id, expression):
322 """replace one object (identify by his id) by a given expression.
323 Note that expression cann't be zero.
324 Note that this should be canonical form (this should contains ONLY
325 MULTVARIABLE) --so this should be called before a factorize.
326 """
327 new = self.__class__()
328
329 for obj in self:
330 assert isinstance(obj, MultVariable)
331 tmp = obj.replace(id, expression)
332 new += tmp
333 new.prefactor = self.prefactor
334 return new
335
336
338 """Pass from High level object to low level object"""
339
340 if not self:
341 return self
342 if self.prefactor == 1:
343 new = self[0].expand(veto)
344 else:
345 new = self.prefactor * self[0].expand(veto)
346
347 for item in self[1:]:
348 if self.prefactor == 1:
349 try:
350 new += item.expand(veto)
351 except AttributeError:
352 new = new + item
353
354 else:
355 new += (self.prefactor) * item.expand(veto)
356 return new
357
359 """define the multiplication of
360 - a AddVariable with a number
361 - a AddVariable with an AddVariable
362 other type of multiplication are define via the symmetric operation base
363 on the obj class."""
364
365
366 if not hasattr(obj, 'vartype'):
367 if not obj:
368 return 0
369 return self.__class__(self, self.prefactor*obj)
370 elif obj.vartype == 1:
371 new = self.__class__([],self.prefactor * obj.prefactor)
372 new[:] = [i*j for i in self for j in obj]
373 return new
374 else:
375
376 return NotImplemented
377
379 """define the multiplication of
380 - a AddVariable with a number
381 - a AddVariable with an AddVariable
382 other type of multiplication are define via the symmetric operation base
383 on the obj class."""
384
385 if not hasattr(obj, 'vartype'):
386 if not obj:
387 return 0
388 self.prefactor *= obj
389 return self
390 elif obj.vartype == 1:
391 new = self.__class__([], self.prefactor * obj.prefactor)
392 new[:] = [i*j for i in self for j in obj]
393 return new
394 else:
395
396 return NotImplemented
397
399 self.prefactor *= -1
400 return self
401
403 """Define all the different addition."""
404
405 if not hasattr(obj, 'vartype'):
406 if not obj:
407 return self
408 new = self.__class__(self, self.prefactor)
409 new.append(obj/self.prefactor)
410 return new
411 elif obj.vartype == 2:
412 new = AddVariable(self, self.prefactor)
413 if self.prefactor == 1:
414 new.append(obj)
415 else:
416 new.append((1/self.prefactor)*obj)
417 return new
418 elif obj.vartype == 1:
419 new = AddVariable(self, self.prefactor)
420 for item in obj:
421 new.append(obj.prefactor/self.prefactor * item)
422 return new
423 else:
424
425 return NotImplemented
426
428 """Define all the different addition."""
429
430 if not hasattr(obj, 'vartype'):
431 if not obj:
432 return self
433 self.append(obj/self.prefactor)
434 return self
435 elif obj.vartype == 2:
436 if self.prefactor == 1:
437 self.append(obj)
438 else:
439 self.append((1/self.prefactor)*obj)
440 return self
441 elif obj.vartype == 1:
442 for item in obj:
443 self.append(obj.prefactor/self.prefactor * item)
444 return self
445 else:
446
447 return NotImplemented
448
450 return self + (-1) * obj
451
453 return (-1) * self + obj
454
455 __radd__ = __add__
456 __rmul__ = __mul__
457
458
461
462 __truediv__ = __div__
463
465 return self.__rmult__(1/obj)
466
468 text = ''
469 if self.prefactor != 1:
470 text += str(self.prefactor) + ' * '
471 text += '( '
472 text += ' + '.join([str(item) for item in self])
473 text += ' )'
474 return text
475
477
478
479 count = defaultdict(int)
480 correlation = defaultdict(defaultdict(int))
481 for i,term in enumerate(self):
482 try:
483 set_term = set(term)
484 except TypeError:
485
486 continue
487 for val1 in set_term:
488 count[val1] +=1
489
490 for val2 in set_term:
491 correlation[val1][val2] += 1
492
493 maxnb = max(count.values()) if count else 0
494 possibility = [v for v,val in count.items() if val == maxnb]
495 if maxnb == 1:
496 return 1, None
497 elif len(possibility) == 1:
498 return maxnb, possibility[0]
499
500
501
502
503 max_wgt, maxvar = 0, None
504 for var in possibility:
505 wgt = sum(w**2 for w in correlation[var].values())/len(correlation[var])
506 if wgt > max_wgt:
507 maxvar = var
508 max_wgt = wgt
509 str_maxvar = str(KERNEL.objs[var])
510 elif wgt == max_wgt:
511
512 new_str = str(KERNEL.objs[var])
513 if new_str < str_maxvar:
514 maxvar = var
515 str_maxvar = new_str
516 return maxnb, maxvar
517
519 """ try to factorize as much as possible the expression """
520
521 max, maxvar = self.count_term()
522 if max <= 1:
523
524 return self
525 else:
526
527 newadd = AddVariable()
528 constant = AddVariable()
529
530 for term in self:
531 try:
532 term.remove(maxvar)
533 except Exception:
534 constant.append(term)
535 else:
536 if len(term):
537 newadd.append(term)
538 else:
539 newadd.append(term.prefactor)
540 newadd = newadd.factorize()
541
542
543 if isinstance(newadd, AddVariable):
544 countprefact = defaultdict(int)
545 nbplus, nbminus = 0,0
546 for nb in [a.prefactor for a in newadd if hasattr(a, 'prefactor')]:
547 countprefact[abs(nb)] +=1
548 if nb.real + nb.imag > 0:
549 nbplus += 1
550 else:
551 nbminus += 1
552
553 newadd.prefactor = sorted(countprefact.items(), key=lambda x: x[1], reverse=True)[0][0]
554 if nbplus < nbminus:
555 newadd.prefactor *= -1
556 if newadd.prefactor != 1:
557 for i,a in enumerate(newadd):
558 try:
559 a.prefactor /= newadd.prefactor
560 except AttributeError:
561 newadd[i] /= newadd.prefactor
562
563
564 if len(constant) > 1:
565 constant = constant.factorize()
566 elif constant:
567 constant = constant[0]
568 else:
569 out = MultContainer([KERNEL.objs[maxvar], newadd])
570 out.prefactor = self.prefactor
571 if newadd.prefactor != 1:
572 out.prefactor *= newadd.prefactor
573 newadd.prefactor = 1
574 return out
575 out = AddVariable([MultContainer([KERNEL.objs[maxvar], newadd]), constant],
576 self.prefactor)
577 return out
578
580
581 vartype = 6
582
586
588 """ String representation """
589 if self.prefactor !=1:
590 text = '(%s * %s)' % (self.prefactor, ' * '.join([str(t) for t in self]))
591 else:
592 text = '(%s)' % (' * '.join([str(t) for t in self]))
593 return text
594
596 self[:] = [term.factorize() for term in self]
597
600 """ A list of Variable with multiplication as operator between themselves.
601 Represented by array for speed optimization
602 """
603 vartype=2
604 addclass = AddVariable
605
606 - def __new__(cls, old=[], prefactor=1):
607 return array.__new__(cls, 'i', old)
608
609
610 - def __init__(self, old=[], prefactor=1):
611 """ initialization of the object with default value """
612
613 self.prefactor = prefactor
614 assert isinstance(self.prefactor, (float,int,long,complex))
615
617 assert len(self) == 1
618 return self[0]
619
621 a = list(self)
622 a.sort()
623 self[:] = array('i',a)
624 return self
625
627 """ simplify the product"""
628 if not len(self):
629 return self.prefactor
630 return self
631
632 - def split(self, variables_id):
633 """return a dict with the key being the power associated to each variables
634 and the value being the object remaining after the suppression of all
635 the variable"""
636
637 key = tuple([self.count(i) for i in variables_id])
638 arg = [id for id in self if id not in variables_id]
639 self[:] = array('i', arg)
640 return SplitCoefficient([(key,self)])
641
642 - def replace(self, id, expression):
643 """replace one object (identify by his id) by a given expression.
644 Note that expression cann't be zero.
645 """
646 assert hasattr(expression, 'vartype') , 'expression should be of type Add or Mult'
647
648 if expression.vartype == 1:
649 nb = self.count(id)
650 if not nb:
651 return self
652 for i in range(nb):
653 self.remove(id)
654 new = self
655 for i in range(nb):
656 new *= expression
657 return new
658 elif expression.vartype == 2:
659
660 nb = self.count(id)
661 for i in range(nb):
662 self.remove(id)
663 self.__imul__(expression)
664 return self
665
666
667
668
669
670
671
672
673
674
675
676 else:
677 raise Exception, 'Cann\'t replace a Variable by %s' % type(expression)
678
679
681 """return the list of variable used in this multiplication"""
682 return ['%s' % KERNEL.objs[n] for n in self]
683
684
685
686
688 """Define the multiplication with different object"""
689
690 if not hasattr(obj, 'vartype'):
691 if obj:
692 return self.__class__(self, obj*self.prefactor)
693 else:
694 return 0
695 elif obj.vartype == 1:
696 new = obj.__class__([], self.prefactor*obj.prefactor)
697 old, self.prefactor = self.prefactor, 1
698 new[:] = [self * term for term in obj]
699 self.prefactor = old
700 return new
701 elif obj.vartype == 4:
702 return NotImplemented
703
704 return self.__class__(array.__add__(self, obj), self.prefactor * obj.prefactor)
705
706 __rmul__ = __mul__
707
709 """Define the multiplication with different object"""
710
711 if not hasattr(obj, 'vartype'):
712 if obj:
713 self.prefactor *= obj
714 return self
715 else:
716 return 0
717 elif obj.vartype == 1:
718 new = obj.__class__([], self.prefactor * obj.prefactor)
719 self.prefactor = 1
720 new[:] = [self * term for term in obj]
721 return new
722 elif obj.vartype == 4:
723 return NotImplemented
724
725 self.prefactor *= obj.prefactor
726 return array.__iadd__(self, obj)
727
729 out = 1
730 for i in range(value):
731 out *= self
732 return out
733
734
736 """ define the adition with different object"""
737
738 if not obj:
739 return self
740 elif not hasattr(obj, 'vartype') or obj.vartype == 2:
741 new = self.addclass([self, obj])
742 return new
743 else:
744
745 return NotImplemented
746 __radd__ = __add__
747 __iadd__ = __add__
748
750 return self + (-1) * obj
751
753 self.prefactor *=-1
754 return self
755
757 return (-1) * self + obj
758
760 """ ONLY NUMBER DIVISION ALLOWED"""
761 assert not hasattr(obj, 'vartype')
762 self.prefactor /= obj
763 return self
764
765 __div__ = __idiv__
766 __truediv__ = __div__
767
768
770 """ String representation """
771 t = ['%s' % KERNEL.objs[n] for n in self]
772 if self.prefactor != 1:
773 text = '(%s * %s)' % (self.prefactor,' * '.join(t))
774 else:
775 text = '(%s)' % (' * '.join(t))
776 return text
777
778 __rep__ = __str__
779
782
790
794
798
801 """This is the standard object for all the variable linked to expression.
802 """
803 mult_class = MultVariable
804
805 - def __new__(cls, name, baseclass, *args):
806 """Factory class return a MultVariable."""
807
808 if name in KERNEL:
809 return cls.mult_class([KERNEL[name]])
810 else:
811 obj = baseclass(name, *args)
812 id = KERNEL.add(name, obj)
813 obj.id = id
814 return cls.mult_class([id])
815
820
833
834
835
836
837
838
839
840
841
842
843
844
845 -class MultLorentz(MultVariable):
846 """Specific class for LorentzObject Multiplication"""
847
848 add_class = AddVariable
849
851 """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
852 the contraction in this Multiplication."""
853
854 out = {}
855 len_mult = len(self)
856
857 for i, fact in enumerate(self):
858
859 for j in range(len(fact.lorentz_ind)):
860
861 for k in range(i+1,len_mult):
862 fact2 = self[k]
863 try:
864 l = fact2.lorentz_ind.index(fact.lorentz_ind[j])
865 except Exception:
866 pass
867 else:
868 out[(i, j)] = (k, l)
869 out[(k, l)] = (i, j)
870 return out
871
873 """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
874 the contraction in this Multiplication."""
875
876 out = {}
877 len_mult = len(self)
878
879 for i, fact in enumerate(self):
880
881 for j in range(len(fact.spin_ind)):
882
883 for k in range(i+1, len_mult):
884 fact2 = self[k]
885 try:
886 l = fact2.spin_ind.index(fact.spin_ind[j])
887 except Exception:
888 pass
889 else:
890 out[(i, j)] = (k, l)
891 out[(k, l)] = (i, j)
892
893 return out
894
896 """return one variable which are contracted with var and not yet expanded"""
897
898 for var in self.unused:
899 obj = KERNEL.objs[var]
900 if obj.has_component(home.lorentz_ind, home.spin_ind):
901 return obj
902 return None
903
904
905
906
908 """ expand each part of the product and combine them.
909 Try to use a smart order in order to minimize the number of uncontracted indices.
910 Veto forbids the use of sub-expression if it contains some of the variable in the
911 expression. Veto contains the id of the vetoed variables
912 """
913
914 self.unused = self[:]
915
916 basic_end_point = [var for var in self if KERNEL.objs[var].contract_first]
917 product_term = []
918 current = None
919
920 while self.unused:
921
922 if not current:
923
924 try:
925
926 current = basic_end_point.pop()
927 except Exception:
928
929 current = self.unused.pop()
930 else:
931
932 if current not in self.unused:
933 current = None
934 continue
935
936 self.unused.remove(current)
937 cur_obj = KERNEL.objs[current]
938
939 product_term.append(cur_obj.expand())
940
941
942 var_obj = self.neighboor(product_term[-1])
943
944
945 if var_obj:
946 product_term[-1] *= var_obj.expand()
947 cur_obj = var_obj
948 self.unused.remove(cur_obj.id)
949 continue
950
951 current = None
952
953
954
955
956 out = self.prefactor
957 for fact in product_term[:]:
958 if hasattr(fact, 'vartype') and fact.lorentz_ind == fact.spin_ind == []:
959 scalar = fact.get_rep([0])
960 if hasattr(scalar, 'vartype') and scalar.vartype == 1:
961 if not veto or not scalar.contains(veto):
962 scalar = scalar.simplify()
963 prefactor = 1
964
965 if hasattr(scalar, 'vartype') and scalar.prefactor not in [1,-1]:
966 prefactor = scalar.prefactor
967 scalar.prefactor = 1
968 new = KERNEL.add_expression_contraction(scalar)
969 fact.set_rep([0], prefactor * new)
970 out *= fact
971 return out
972
974 """ create a shadow copy """
975 new = MultLorentz(self)
976 new.prefactor = self.prefactor
977 return new
978
983 """ A symbolic Object for All Helas object. All Helas Object Should
984 derivated from this class"""
985
986 contract_first = 0
987 mult_class = MultLorentz
988 add_class = AddVariable
989
990 - def __init__(self, name, lor_ind, spin_ind, tags=[]):
991 """ initialization of the object with default value """
992 assert isinstance(lor_ind, list)
993 assert isinstance(spin_ind, list)
994
995 self.name = name
996 self.lorentz_ind = lor_ind
997 self.spin_ind = spin_ind
998 KERNEL.add_tag(set(tags))
999
1001 """Expand the content information into LorentzObjectRepresentation."""
1002
1003 try:
1004 return self.representation
1005 except Exception:
1006 self.create_representation()
1007 return self.representation
1008
1010 raise self.VariableError("This Object %s doesn't have define representation" % self.__class__.__name__)
1011
1013 """check if this Lorentz Object have some of those indices"""
1014
1015 if any([id in self.lorentz_ind for id in lor_list]) or \
1016 any([id in self.spin_ind for id in spin_list]):
1017 return True
1018
1019
1020
1022 return '%s' % self.name
1023
1025 """ A symbolic Object for All Helas object. All Helas Object Should
1026 derivated from this class"""
1027
1028 mult_class = MultLorentz
1029 object_class = LorentzObject
1030
1034
1035 @classmethod
1037 """default way to have a unique name"""
1038 return '_L_%(class)s_%(args)s' % \
1039 {'class':cls.__name__,
1040 'args': '_'.join(args)
1041 }
1042
1048 """A concrete representation of the LorentzObject."""
1049
1050 vartype = 4
1051
1053 """Specify error for LorentzObjectRepresentation"""
1054
1055 - def __init__(self, representation, lorentz_indices, spin_indices):
1056 """ initialize the lorentz object representation"""
1057
1058 self.lorentz_ind = lorentz_indices
1059 self.nb_lor = len(lorentz_indices)
1060 self.spin_ind = spin_indices
1061 self.nb_spin = len(spin_indices)
1062 self.nb_ind = self.nb_lor + self.nb_spin
1063
1064
1065
1066 if self.lorentz_ind or self.spin_ind:
1067 dict.__init__(self, representation)
1068 elif isinstance(representation,dict):
1069 if len(representation) == 0:
1070 self[(0,)] = 0
1071 elif len(representation) == 1 and (0,) in representation:
1072 self[(0,)] = representation[(0,)]
1073 else:
1074 raise self.LorentzObjectRepresentationError("There is no key of (0,) in representation.")
1075 else:
1076 if isinstance(representation,dict):
1077 try:
1078 self[(0,)] = representation[(0,)]
1079 except Exception:
1080 if representation:
1081 raise LorentzObjectRepresentation.LorentzObjectRepresentationError("There is no key of (0,) in representation.")
1082 else:
1083 self[(0,)] = 0
1084 else:
1085 self[(0,)] = representation
1086
1088 """ string representation """
1089 text = 'lorentz index :' + str(self.lorentz_ind) + '\n'
1090 text += 'spin index :' + str(self.spin_ind) + '\n'
1091
1092 for ind in self.listindices():
1093 ind = tuple(ind)
1094 text += str(ind) + ' --> '
1095 text += str(self.get_rep(ind)) + '\n'
1096 return text
1097
1099 """return the value/Variable associate to the indices"""
1100 return self[tuple(indices)]
1101
1102 - def set_rep(self, indices, value):
1103 """assign 'value' at the indices position"""
1104
1105 self[tuple(indices)] = value
1106
1108 """Return an iterator in order to be able to loop easily on all the
1109 indices of the object."""
1110 return IndicesIterator(self.nb_ind)
1111
1112 @staticmethod
1114 shift = len(switch_order)
1115 for value in l1:
1116 try:
1117 index = l2.index(value)
1118 except Exception:
1119 raise LorentzObjectRepresentation.LorentzObjectRepresentationError(
1120 "Invalid addition. Object doen't have the same lorentz "+ \
1121 "indices : %s != %s" % (l1, l2))
1122 else:
1123 switch_order.append(shift + index)
1124 return switch_order
1125
1126
1128
1129 if not obj:
1130 return self
1131
1132 if not hasattr(obj, 'vartype'):
1133 assert self.lorentz_ind == []
1134 assert self.spin_ind == []
1135 new = self[(0,)] + obj * fact
1136 out = LorentzObjectRepresentation(new, [], [])
1137 return out
1138
1139 assert(obj.vartype == 4 == self.vartype)
1140
1141 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:
1142
1143 switch_order = []
1144 self.get_mapping(self.lorentz_ind, obj.lorentz_ind, switch_order)
1145 self.get_mapping(self.spin_ind, obj.spin_ind, switch_order)
1146 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))])
1147 else:
1148
1149 switch = lambda ind : (ind)
1150
1151
1152 assert tuple(self.lorentz_ind+self.spin_ind) == tuple(switch(obj.lorentz_ind+obj.spin_ind)), '%s!=%s' % (self.lorentz_ind+self.spin_ind, switch(obj.lorentz_ind+self.spin_ind))
1153 assert tuple(self.lorentz_ind) == tuple(switch(obj.lorentz_ind)), '%s!=%s' % (tuple(self.lorentz_ind), switch(obj.lorentz_ind))
1154
1155
1156 new = LorentzObjectRepresentation({}, obj.lorentz_ind, obj.spin_ind)
1157
1158
1159 if fact == 1:
1160 for ind in self.listindices():
1161 value = obj.get_rep(ind) + self.get_rep(switch(ind))
1162 new.set_rep(ind, value)
1163 else:
1164 for ind in self.listindices():
1165 value = fact * obj.get_rep(switch(ind)) + self.get_rep(ind)
1166 new.set_rep(ind, value)
1167
1168 return new
1169
1171
1172 if not obj:
1173 return self
1174
1175 assert(obj.vartype == 4 == self.vartype)
1176
1177 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:
1178
1179
1180 switch_order = []
1181 self.get_mapping(obj.lorentz_ind, self.lorentz_ind, switch_order)
1182 self.get_mapping(obj.spin_ind, self.spin_ind, switch_order)
1183 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))])
1184 else:
1185
1186 switch = lambda ind : (ind)
1187
1188
1189 assert tuple(switch(self.lorentz_ind+self.spin_ind)) == tuple(obj.lorentz_ind+obj.spin_ind), '%s!=%s' % (switch(self.lorentz_ind+self.spin_ind), (obj.lorentz_ind+obj.spin_ind))
1190 assert tuple(switch(self.lorentz_ind) )== tuple(obj.lorentz_ind), '%s!=%s' % (switch(self.lorentz_ind), tuple(obj.lorentz_ind))
1191
1192
1193 if fact == 1:
1194 for ind in self.listindices():
1195 self[tuple(ind)] += obj.get_rep(switch(ind))
1196 else:
1197 for ind in self.listindices():
1198 self[tuple(ind)] += fact * obj.get_rep(switch(ind))
1199 return self
1200
1202 return self.__add__(obj, fact= -1)
1203
1205 return obj.__add__(self, fact= -1)
1206
1208 return self.__add__(obj, fact= -1)
1209
1211 self *= -1
1212 return self
1213
1215 """multiplication performing directly the einstein/spin sommation.
1216 """
1217
1218 if not hasattr(obj, 'vartype'):
1219 out = LorentzObjectRepresentation({}, self.lorentz_ind, self.spin_ind)
1220 for ind in out.listindices():
1221 out.set_rep(ind, obj * self.get_rep(ind))
1222 return out
1223
1224
1225 assert(obj.__class__ == LorentzObjectRepresentation), \
1226 '%s is not valid class for this operation' %type(obj)
1227
1228
1229
1230 l_ind, sum_l_ind = self.compare_indices(self.lorentz_ind, \
1231 obj.lorentz_ind)
1232 s_ind, sum_s_ind = self.compare_indices(self.spin_ind, \
1233 obj.spin_ind)
1234 if not(sum_l_ind or sum_s_ind):
1235
1236 return self.tensor_product(obj)
1237
1238
1239
1240 new_object = LorentzObjectRepresentation({}, l_ind, s_ind)
1241
1242 for indices in new_object.listindices():
1243
1244 dict_l_ind = self.pass_ind_in_dict(indices[:len(l_ind)], l_ind)
1245 dict_s_ind = self.pass_ind_in_dict(indices[len(l_ind):], s_ind)
1246
1247 new_object.set_rep(indices, \
1248 self.contraction(obj, sum_l_ind, sum_s_ind, \
1249 dict_l_ind, dict_s_ind))
1250
1251 return new_object
1252
1253 __rmul__ = __mul__
1254 __imul__ = __mul__
1255
1256 - def contraction(self, obj, l_sum, s_sum, l_dict, s_dict):
1257 """ make the Lorentz/spin contraction of object self and obj.
1258 l_sum/s_sum are the position of the sum indices
1259 l_dict/s_dict are dict given the value of the fix indices (indices->value)
1260 """
1261 out = 0
1262 len_l = len(l_sum)
1263 len_s = len(s_sum)
1264
1265
1266
1267 for l_value in IndicesIterator(len_l):
1268 l_dict.update(self.pass_ind_in_dict(l_value, l_sum))
1269 for s_value in IndicesIterator(len_s):
1270
1271 s_dict.update(self.pass_ind_in_dict(s_value, s_sum))
1272
1273
1274 self_ind = self.combine_indices(l_dict, s_dict)
1275 obj_ind = obj.combine_indices(l_dict, s_dict)
1276
1277
1278 factor = obj.get_rep(obj_ind) * self.get_rep(self_ind)
1279
1280 if factor:
1281
1282 try:
1283 factor.prefactor *= (-1) ** (len(l_value) - l_value.count(0))
1284 except Exception:
1285 factor *= (-1) ** (len(l_value) - l_value.count(0))
1286 out += factor
1287 return out
1288
1290 """ return the tensorial product of the object"""
1291 assert(obj.vartype == 4)
1292
1293 new_object = LorentzObjectRepresentation({}, \
1294 self.lorentz_ind + obj.lorentz_ind, \
1295 self.spin_ind + obj.spin_ind)
1296
1297
1298 lor1 = self.nb_lor
1299 lor2 = obj.nb_lor
1300 spin1 = self.nb_spin
1301 spin2 = obj.nb_spin
1302
1303
1304 if lor1 == 0 == spin1:
1305
1306 selfind = lambda indices: [0]
1307 else:
1308 selfind = lambda indices: indices[:lor1] + \
1309 indices[lor1 + lor2: lor1 + lor2 + spin1]
1310
1311
1312 if lor2 == 0 == spin2:
1313
1314 objind = lambda indices: [0]
1315 else:
1316 objind = lambda indices: indices[lor1: lor1 + lor2] + \
1317 indices[lor1 + lor2 + spin1:]
1318
1319
1320 for indices in new_object.listindices():
1321
1322 fac1 = self.get_rep(tuple(selfind(indices)))
1323 fac2 = obj.get_rep(tuple(objind(indices)))
1324 new_object.set_rep(indices, fac1 * fac2)
1325
1326 return new_object
1327
1329 """Try to factorize each component"""
1330 for ind, fact in self.items():
1331 if fact:
1332 self.set_rep(ind, fact.factorize())
1333
1334
1335 return self
1336
1338 """Check if we can simplify the object (check for non treated Sum)"""
1339
1340
1341 for ind, term in self.items():
1342 if hasattr(term, 'vartype'):
1343 self[ind] = term.simplify()
1344
1345 return self
1346
1347 @staticmethod
1349 """return two list, the first one contains the position of non summed
1350 index and the second one the position of summed index."""
1351
1352
1353
1354
1355
1356
1357 are_unique, are_sum = [], []
1358
1359
1360 for indice in list1:
1361 if indice in list2:
1362 are_sum.append(indice)
1363 else:
1364 are_unique.append(indice)
1365
1366
1367 for indice in list2:
1368 if indice not in are_sum:
1369 are_unique.append(indice)
1370
1371
1372 return are_unique, are_sum
1373
1374 @staticmethod
1376 """made a dictionary (pos -> index_value) for how call the object"""
1377 if not key:
1378 return {}
1379 out = {}
1380 for i, ind in enumerate(indices):
1381 out[key[i]] = ind
1382 return out
1383
1385 """return the indices in the correct order following the dicts rules"""
1386
1387 out = []
1388
1389 for value in self.lorentz_ind:
1390 out.append(l_dict[value])
1391
1392 for value in self.spin_ind:
1393 out.append(s_dict[value])
1394
1395 return out
1396
1397 - def split(self, variables_id):
1398 """return a dict with the key being the power associated to each variables
1399 and the value being the object remaining after the suppression of all
1400 the variable"""
1401
1402 out = SplitCoefficient()
1403 zero_rep = {}
1404 for ind in self.listindices():
1405 zero_rep[tuple(ind)] = 0
1406
1407 for ind in self.listindices():
1408
1409 if isinstance(self.get_rep(ind), numbers.Number):
1410 if tuple([0]*len(variables_id)) in out:
1411 out[tuple([0]*len(variables_id))][tuple(ind)] += self.get_rep(ind)
1412 else:
1413 out[tuple([0]*len(variables_id))] = \
1414 LorentzObjectRepresentation(dict(zero_rep),
1415 self.lorentz_ind, self.spin_ind)
1416 out[tuple([0]*len(variables_id))][tuple(ind)] += self.get_rep(ind)
1417 continue
1418
1419 for key, value in self.get_rep(ind).split(variables_id).items():
1420 if key in out:
1421 out[key][tuple(ind)] += value
1422 else:
1423 out[key] = LorentzObjectRepresentation(dict(zero_rep),
1424 self.lorentz_ind, self.spin_ind)
1425 out[key][tuple(ind)] += value
1426
1427 return out
1428
1436 """Class needed for the iterator"""
1437
1439 """ create an iterator looping over the indices of a list of len "len"
1440 with each value can take value between 0 and 3 """
1441
1442 self.len = len
1443 if len:
1444
1445
1446 self.data = [-1] + [0] * (len - 1)
1447 else:
1448
1449 self.data = 0
1450 self.next = self.nextscalar
1451
1454
1456 for i in range(self.len):
1457 if self.data[i] < 3:
1458 self.data[i] += 1
1459 return self.data
1460 else:
1461 self.data[i] = 0
1462 raise StopIteration
1463
1465 if self.data:
1466 raise StopIteration
1467 else:
1468 self.data = True
1469 return [0]
1470
1472
1476
1478 """return the highest rank of the coefficient"""
1479
1480 return max([max(arg[:4]) for arg in self])
1481
1482
1483 if '__main__' ==__name__:
1484
1485 import cProfile
1489
1490 cProfile.run('create()')
1491