1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45 from __future__ import division
46 from __future__ import absolute_import
47 from __future__ import print_function
48 from array import array
49 import collections
50 from fractions import Fraction
51 import numbers
52 import re
53 import aloha
54 from six.moves import range
55
56 try:
57 import madgraph.various.misc as misc
58 except Exception:
59 import aloha.misc as misc
65
67 """ a class to encapsulate all computation. Limit side effect """
68
70 self.objs = []
71 self.use_tag = set()
72 self.id = -1
73 self.reduced_expr = {}
74 self.fct_expr = {}
75 self.reduced_expr2 = {}
76 self.inverted_fct = {}
77 self.has_pi = False
78 self.unknow_fct = []
79 dict.__init__(self)
80
84
def add(self, name, obj):
    """Register ``obj`` under ``name`` and return the id allocated for it.

    Ids are handed out sequentially: the new id is ``self.id + 1``. The
    object is appended to ``self.objs`` and ``name`` is mapped to the new id
    in the dict itself.
    """
    new_id = self.id + 1
    self.id = new_id
    self.objs.append(obj)
    self[name] = new_id
    return new_id
90
def get(self, name):
    """Return the object registered under ``name``.

    Unlike ``dict.get`` there is no default value: an unknown name raises
    ``KeyError`` from the name -> id lookup.
    """
    object_id = self[name]
    return self.objs[object_id]
93
96
98 """return the list of identification number associate to the
99 given variables names. If a variable didn't exists, create it (in complex).
100 """
101 out = []
102 for var in variables:
103 try:
104 id = self[var]
105 except KeyError:
106 assert var not in ['M','W']
107 id = Variable(var).get_id()
108 out.append(id)
109 return out
110
111
113
114 str_expr = str(expression)
115 if str_expr in self.reduced_expr:
116 out, tag = self.reduced_expr[str_expr]
117 self.add_tag((tag,))
118 return out
119 if expression == 0:
120 return 0
121 new_2 = expression.simplify()
122 if new_2 == 0:
123 return 0
124
125 tag = 'TMP%s' % len(self.reduced_expr)
126 new = Variable(tag)
127 self.reduced_expr[str_expr] = [new, tag]
128 new_2 = new_2.factorize()
129 self.reduced_expr2[tag] = new_2
130 self.add_tag((tag,))
131
132
133 return new
134
135 known_fct = ['/', 'log', 'pow', 'sin', 'cos', 'asin', 'acos', 'tan', 'cot', 'acot',
136 'theta_function', 'exp']
138
139
140 if not (fct_tag.startswith('cmath.') or fct_tag in self.known_fct or
141 (fct_tag, len(args)) in self.unknow_fct):
142 self.unknow_fct.append( (fct_tag, len(args)) )
143
144 argument = []
145 for expression in args:
146 if isinstance(expression, (MultLorentz, AddVariable, LorentzObject)):
147 try:
148 expr = expression.expand().get_rep([0])
149 except KeyError as error:
150 if error.args != ((0,),):
151 raise
152 else:
153 raise aloha.ALOHAERROR('''Error in input format.
154 Argument of function (or denominator) should be scalar.
155 We found %s''' % expression)
156 new = expr.simplify()
157 if not isinstance(new, numbers.Number):
158 new = new.factorize()
159 argument.append(new)
160 else:
161 argument.append(expression)
162
163 for arg in argument:
164 val = re.findall(r'''\bFCT(\d*)\b''', str(arg))
165 for v in val:
166 self.add_tag(('FCT%s' % v,))
167
168
169 if (fct_tag.startswith('cmath.') or fct_tag in self.known_fct) and \
170 all(isinstance(x, numbers.Number) for x in argument):
171 import cmath
172 if fct_tag.startswith('cmath.'):
173 module = ''
174 else:
175 module = 'cmath.'
176 try:
177 return str(eval("%s%s(%s)" % (module,fct_tag, ','.join(repr(x) for x in argument))))
178 except Exception as error:
179 print(error)
180 print("cmath.%s(%s)" % (fct_tag, ','.join(repr(x) for x in argument)))
181 if str(fct_tag)+str(argument) in self.inverted_fct:
182 tag = self.inverted_fct[str(fct_tag)+str(argument)]
183 v = tag.split('(')[1][:-1]
184 self.add_tag(('FCT%s' % v,))
185 return tag
186 else:
187 id = len(self.fct_expr)
188 tag = 'FCT%s' % id
189 self.inverted_fct[str(fct_tag)+str(argument)] = 'FCT(%s)' % id
190 self.fct_expr[tag] = (fct_tag, argument)
191 self.reduced_expr2[tag] = (fct_tag, argument)
192 self.add_tag((tag,))
193
194 return 'FCT(%s)' % id
195
# Module-level Computation instance. The classes below register and look up
# their variables through it (KERNEL.add, KERNEL[name], KERNEL.objs, ...).
196 KERNEL = Computation()
202 """ A list of Variable/ConstantObject/... This object represent the operation
203 between those object."""
204
205
206 vartype = 1
207
def __init__(self, old_data=(), prefactor=1):
    """Initialise the sum from the terms in ``old_data``.

    :param old_data: iterable of terms. It is copied into the list, so the
        caller's container is never aliased. The default is an immutable
        empty tuple rather than a shared mutable list.
    :param prefactor: numerical factor multiplying the whole sum.
    """
    self.prefactor = prefactor
    list.__init__(self, old_data)
214
216 """ apply rule of simplification """
217
218
219 if len(self) == 1:
220 return self.prefactor * self[0].simplify()
221 constant = 0
222 items = {}
223 pos = -1
224 for term in self[:]:
225 pos += 1
226 if not hasattr(term, 'vartype'):
227 if isinstance(term, dict):
228
229 assert list(term.values()) == [0]
230 term = term[(0,)]
231 constant += term
232 del self[pos]
233 pos -= 1
234 continue
235 tag = tuple(term.sort())
236 if tag in items:
237 orig_prefac = items[tag].prefactor
238 items[tag].prefactor += term.prefactor
239 if items[tag].prefactor and \
240 abs(items[tag].prefactor) / (abs(orig_prefac)+abs(term.prefactor)) < 1e-8:
241 items[tag].prefactor = 0
242 del self[pos]
243 pos -=1
244 else:
245 items[tag] = term.__class__(term, term.prefactor)
246 self[pos] = items[tag]
247
248
249 countprefact = defaultdict(int)
250 nbplus, nbminus = 0,0
251 if constant not in [0, 1,-1]:
252 countprefact[constant] += 1
253 if constant.real + constant.imag > 0:
254 nbplus += 1
255 else:
256 nbminus += 1
257
258 for var in items.values():
259 if var.prefactor == 0:
260 self.remove(var)
261 else:
262 nb = var.prefactor
263 if nb in [1,-1]:
264 continue
265 countprefact[abs(nb)] +=1
266 if nb.real + nb.imag > 0:
267 nbplus += 1
268 else:
269 nbminus += 1
270 if countprefact and max(countprefact.values()) >1:
271 fact_prefactor = sorted(list(countprefact.items()), key=lambda x: x[1], reverse=True)[0][0]
272 else:
273 fact_prefactor = 1
274 if nbplus < nbminus:
275 fact_prefactor *= -1
276 self.prefactor *= fact_prefactor
277
278 if fact_prefactor != 1:
279 for i,a in enumerate(self):
280 try:
281 a.prefactor /= fact_prefactor
282 except AttributeError:
283 self[i] /= fact_prefactor
284
285 if constant:
286 self.append(constant/ fact_prefactor )
287
288
289 varlen = len(self)
290 if varlen == 1:
291 if hasattr(self[0], 'vartype'):
292 return self.prefactor * self[0].simplify()
293 else:
294
295 return self.prefactor * self[0]
296 elif varlen == 0:
297 return 0
298 return self
299
def split(self, variables_id):
    """Split the sum according to the powers of the given variables.

    Each term supplies its own split, a mapping from power-tuple to the
    remaining coefficient. The coefficients are multiplied by the sum's
    ``prefactor`` and accumulated per power-tuple.

    :param variables_id: ids of the variables to factor out.
    :return: a ``defaultdict(int)`` mapping power-tuple -> coefficient.
    """
    collected = defaultdict(int)
    for term in self:
        term_parts = term.split(variables_id)
        for powers, coefficient in term_parts.items():
            collected[powers] += self.prefactor * coefficient
    return collected
310
312 """returns true if one of the variables is in the expression"""
313
314 return any((v in obj for obj in self for v in variables ))
315
316
318
319 out = []
320 for term in self:
321 if hasattr(term, 'get_all_var_names'):
322 out += term.get_all_var_names()
323 return out
324
325
326
def replace(self, id, expression):
    """Return a sum in which variable ``id`` is replaced by ``expression``.

    ``expression`` must not be zero. ``self`` is expected to be in canonical
    form, with every term a ``MultVariable``, so call this before
    ``factorize``. Each term performs its own replacement, and the results
    are accumulated starting from an empty instance of ``self``'s class. The
    accumulated object is then given ``self.prefactor``.
    """
    result = self.__class__()
    for term in self:
        assert isinstance(term, MultVariable)
        result += term.replace(id, expression)
    result.prefactor = self.prefactor
    return result
341
342
344 """Pass from High level object to low level object"""
345
346 if not self:
347 return self
348 if self.prefactor == 1:
349 new = self[0].expand(veto)
350 else:
351 new = self.prefactor * self[0].expand(veto)
352
353 for item in self[1:]:
354 if self.prefactor == 1:
355 try:
356 new += item.expand(veto)
357 except AttributeError:
358 new = new + item
359
360 else:
361 new += (self.prefactor) * item.expand(veto)
362 return new
363
365 """define the multiplication of
366 - a AddVariable with a number
367 - a AddVariable with an AddVariable
368 other type of multiplication are define via the symmetric operation base
369 on the obj class."""
370
371
372 if not hasattr(obj, 'vartype'):
373 if not obj:
374 return 0
375 return self.__class__(self, self.prefactor*obj)
376 elif obj.vartype == 1:
377 new = self.__class__([],self.prefactor * obj.prefactor)
378 new[:] = [i*j for i in self for j in obj]
379 return new
380 else:
381
382 return NotImplemented
383
385 """define the multiplication of
386 - a AddVariable with a number
387 - a AddVariable with an AddVariable
388 other type of multiplication are define via the symmetric operation base
389 on the obj class."""
390
391 if not hasattr(obj, 'vartype'):
392 if not obj:
393 return 0
394 self.prefactor *= obj
395 return self
396 elif obj.vartype == 1:
397 new = self.__class__([], self.prefactor * obj.prefactor)
398 new[:] = [i*j for i in self for j in obj]
399 return new
400 else:
401
402 return NotImplemented
403
405 self.prefactor *= -1
406 return self
407
409 """Define all the different addition."""
410
411 if not hasattr(obj, 'vartype'):
412 if not obj:
413 return self
414 new = self.__class__(self, self.prefactor)
415 new.append(obj/self.prefactor)
416 return new
417 elif obj.vartype == 2:
418 new = AddVariable(self, self.prefactor)
419 if self.prefactor == 1:
420 new.append(obj)
421 else:
422 new.append((1/self.prefactor)*obj)
423 return new
424 elif obj.vartype == 1:
425 new = AddVariable(self, self.prefactor)
426 for item in obj:
427 new.append(obj.prefactor/self.prefactor * item)
428 return new
429 else:
430
431 return NotImplemented
432
434 """Define all the different addition."""
435
436 if not hasattr(obj, 'vartype'):
437 if not obj:
438 return self
439 self.append(obj/self.prefactor)
440 return self
441 elif obj.vartype == 2:
442 if self.prefactor == 1:
443 self.append(obj)
444 else:
445 self.append((1/self.prefactor)*obj)
446 return self
447 elif obj.vartype == 1:
448 for item in obj:
449 self.append(obj.prefactor/self.prefactor * item)
450 return self
451 else:
452
453 return NotImplemented
454
456 return self + (-1) * obj
457
459 return (-1) * self + obj
460
461 __radd__ = __add__
462 __rmul__ = __mul__
463
464
467
468 __truediv__ = __div__
469
471 return self.__rmult__(1/obj)
472
474 text = ''
475 if self.prefactor != 1:
476 text += str(self.prefactor) + ' * '
477 text += '( '
478 text += ' + '.join([str(item) for item in self])
479 text += ' )'
480 return text
481
483 text = ''
484 if self.prefactor != 1:
485 text += str(self.prefactor) + ' * '
486 text += super(AddVariable,self).__repr__()
487 return text
488
490
491
492 count = defaultdict(int)
493 correlation = defaultdict(defaultdict(int))
494 for i,term in enumerate(self):
495 try:
496 set_term = set(term)
497 except TypeError:
498
499 continue
500 for val1 in set_term:
501 count[val1] +=1
502
503 for val2 in set_term:
504 correlation[val1][val2] += 1
505
506 maxnb = max(count.values()) if count else 0
507 possibility = [v for v,val in count.items() if val == maxnb]
508 if maxnb == 1:
509 return 1, None
510 elif len(possibility) == 1:
511 return maxnb, possibility[0]
512
513
514
515
516 max_wgt, maxvar = 0, None
517 for var in possibility:
518 wgt = sum(w**2 for w in correlation[var].values())/len(correlation[var])
519 if wgt > max_wgt:
520 maxvar = var
521 max_wgt = wgt
522 str_maxvar = str(KERNEL.objs[var])
523 elif wgt == max_wgt:
524
525 new_str = str(KERNEL.objs[var])
526 if new_str < str_maxvar:
527 maxvar = var
528 str_maxvar = new_str
529 return maxnb, maxvar
530
532 """ try to factorize as much as possible the expression """
533
534 max, maxvar = self.count_term()
535 if max <= 1:
536
537 return self
538 else:
539
540 newadd = AddVariable()
541 constant = AddVariable()
542
543 for term in self:
544 try:
545 term.remove(maxvar)
546 except Exception:
547 constant.append(term)
548 else:
549 if len(term):
550 newadd.append(term)
551 else:
552 newadd.append(term.prefactor)
553 newadd = newadd.factorize()
554
555
556 if isinstance(newadd, AddVariable):
557 countprefact = defaultdict(int)
558 nbplus, nbminus = 0,0
559 for nb in [a.prefactor for a in newadd if hasattr(a, 'prefactor')]:
560 countprefact[abs(nb)] +=1
561 if nb.real + nb.imag > 0:
562 nbplus += 1
563 else:
564 nbminus += 1
565
566 newadd.prefactor = sorted(list(countprefact.items()), key=lambda x: x[1], reverse=True)[0][0]
567 if nbplus < nbminus:
568 newadd.prefactor *= -1
569 if newadd.prefactor != 1:
570 for i,a in enumerate(newadd):
571 try:
572 a.prefactor /= newadd.prefactor
573 except AttributeError:
574 newadd[i] /= newadd.prefactor
575
576
577 if len(constant) > 1:
578 constant = constant.factorize()
579 elif constant:
580 constant = constant[0]
581 else:
582 out = MultContainer([KERNEL.objs[maxvar], newadd])
583 out.prefactor = self.prefactor
584 if newadd.prefactor != 1:
585 out.prefactor *= newadd.prefactor
586 newadd.prefactor = 1
587 return out
588 out = AddVariable([MultContainer([KERNEL.objs[maxvar], newadd]), constant],
589 self.prefactor)
590 return out
591
593
594 vartype = 6
595
599
601 """ String representation """
602 if self.prefactor !=1:
603 text = '(%s * %s)' % (self.prefactor, ' * '.join([str(t) for t in self]))
604 else:
605 text = '(%s)' % (' * '.join([str(t) for t in self]))
606 return text
607
609 self[:] = [term.factorize() for term in self]
610
613 """ A list of Variable with multiplication as operator between themselves.
614 Represented by array for speed optimization
615 """
616 vartype=2
617 addclass = AddVariable
618
def __new__(cls, old=(), prefactor=1):
    """Allocate the product as an ``array`` of type code ``'i'``.

    :param old: iterable of variable ids used to initialise the array. The
        default is an immutable empty tuple rather than a shared mutable list.
    :param prefactor: accepted so that the signature matches ``__init__``;
        it is not used here.
    """
    return array.__new__(cls, 'i', old)
621
622
def __init__(self, old=(), prefactor=1):
    """Store the numerical ``prefactor`` of the product.

    :param old: not used here (the array contents are set at allocation);
        the default is an immutable empty tuple rather than a mutable list.
    :param prefactor: numerical factor; must be a float, int or complex.
    """
    self.prefactor = prefactor
    assert isinstance(self.prefactor, (float, int, complex))
628
630 assert len(self) == 1
631 return self[0]
632
634 a = list(self)
635 a.sort()
636 self[:] = array('i',a)
637 return self
638
640 """ simplify the product"""
641 if not len(self):
642 return self.prefactor
643 return self
644
def split(self, variables_id):
    """Factor the given variables out of this product.

    :param variables_id: ids of the variables to extract.
    :return: a ``SplitCoefficient`` holding the single pair ``(powers, self)``,
        where ``powers[i]`` is the number of occurrences of
        ``variables_id[i]``.

    Side effect: ``self`` is modified in place. Every id listed in
    ``variables_id`` is removed from it before it is returned.
    """
    powers = tuple(self.count(var_id) for var_id in variables_id)
    kept = array('i', (var_id for var_id in self if var_id not in variables_id))
    self[:] = kept
    return SplitCoefficient([(powers, self)])
654
655 - def replace(self, id, expression):
656 """replace one object (identify by his id) by a given expression.
657 Note that expression cann't be zero.
658 """
# Replace every occurrence of variable ``id`` in this product by
# ``expression``. ``expression.vartype`` selects the strategy:
#   1 (a sum, AddVariable)       -> strip the occurrences, then multiply by the
#                                   sum once per stripped occurrence;
#   2 (a product, MultVariable)  -> strip the occurrences and merge the
#                                   product in place via ``__imul__``;
#   anything else                -> raise.
659 assert hasattr(expression, 'vartype') , 'expression should be of type Add or Mult'
660
661 if expression.vartype == 1:
662 nb = self.count(id)
663 if not nb:
664 return self
665 for i in range(nb):
666 self.remove(id)
# ``new *= expression`` may rebind ``new`` to a different object (multiplying
# a product by a sum builds a new object), so ``new`` rather than ``self`` is
# returned.
667 new = self
668 for i in range(nb):
669 new *= expression
670 return new
671 elif expression.vartype == 2:
672
673 nb = self.count(id)
674 for i in range(nb):
675 self.remove(id)
# NOTE(review): indentation is missing from this listing, so the nesting of the
# next call cannot be recovered. Inside the loop above, the product is
# multiplied by ``expression`` once per removed occurrence (matching the
# vartype-1 branch). After the loop, it is applied exactly once, even when
# ``id`` was absent. Confirm against the original source.
676 self.__imul__(expression)
677 return self
678
679
680
681
682
683
684
685
686
687
688
689 else:
690 raise Exception('Cann\'t replace a Variable by %s' % type(expression))
691
692
694 """return the list of variable used in this multiplication"""
695 return ['%s' % KERNEL.objs[n] for n in self]
696
697
698
699
701 """Define the multiplication with different object"""
702
703 if not hasattr(obj, 'vartype'):
704 if obj:
705 return self.__class__(self, obj*self.prefactor)
706 else:
707 return 0
708 elif obj.vartype == 1:
709 new = obj.__class__([], self.prefactor*obj.prefactor)
710 old, self.prefactor = self.prefactor, 1
711 new[:] = [self * term for term in obj]
712 self.prefactor = old
713 return new
714 elif obj.vartype == 4:
715 return NotImplemented
716
717 return self.__class__(array.__add__(self, obj), self.prefactor * obj.prefactor)
718
719 __rmul__ = __mul__
720
722 """Define the multiplication with different object"""
723
724 if not hasattr(obj, 'vartype'):
725 if obj:
726 self.prefactor *= obj
727 return self
728 else:
729 return 0
730 elif obj.vartype == 1:
731 new = obj.__class__([], self.prefactor * obj.prefactor)
732 self.prefactor = 1
733 new[:] = [self * term for term in obj]
734 return new
735 elif obj.vartype == 4:
736 return NotImplemented
737
738 self.prefactor *= obj.prefactor
739 return array.__iadd__(self, obj)
740
742 out = 1
743 for i in range(value):
744 out *= self
745 return out
746
747
749 """ define the adition with different object"""
750
751 if not obj:
752 return self
753 elif not hasattr(obj, 'vartype') or obj.vartype == 2:
754 new = self.addclass([self, obj])
755 return new
756 else:
757
758 return NotImplemented
759 __radd__ = __add__
760 __iadd__ = __add__
761
763 return self + (-1) * obj
764
766 self.prefactor *=-1
767 return self
768
770 return (-1) * self + obj
771
773 """ ONLY NUMBER DIVISION ALLOWED"""
774 assert not hasattr(obj, 'vartype')
775 self.prefactor /= obj
776 return self
777
778 __div__ = __idiv__
779 __truediv__ = __div__
780
781
783 """ String representation """
784 t = ['%s' % KERNEL.objs[n] for n in self]
785 if self.prefactor != 1:
786 text = '(%s * %s)' % (self.prefactor,' * '.join(t))
787 else:
788 text = '(%s)' % (' * '.join(t))
789 return text
790
791 __repr__ = __str__
792
795
803
807
811
814 """This is the standard object for all the variable linked to expression.
815 """
816 mult_class = MultVariable
817
def __new__(cls, name, baseclass, *args):
    """Factory: return a ``cls.mult_class`` holding the id of ``name``.

    The first time ``name`` is seen, ``baseclass(name, *args)`` is built,
    registered in ``KERNEL`` and tagged with the returned id. Later calls
    reuse the id already stored in ``KERNEL`` without constructing anything.
    """
    if name not in KERNEL:
        instance = baseclass(name, *args)
        new_id = KERNEL.add(name, instance)
        instance.id = new_id
        return cls.mult_class([new_id])
    return cls.mult_class([KERNEL[name]])
828
833
846
847
848
849
850
851
852
853
854
855
856
857
858 -class MultLorentz(MultVariable):
859 """Specific class for LorentzObject Multiplication"""
860
861 add_class = AddVariable
862
864 """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
865 the contraction in this Multiplication."""
866
867 out = {}
868 len_mult = len(self)
869
870 for i, fact in enumerate(self):
871
872 for j in range(len(fact.lorentz_ind)):
873
874 for k in range(i+1,len_mult):
875 fact2 = self[k]
876 try:
877 l = fact2.lorentz_ind.index(fact.lorentz_ind[j])
878 except Exception:
879 pass
880 else:
881 out[(i, j)] = (k, l)
882 out[(k, l)] = (i, j)
883 return out
884
886 """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
887 the contraction in this Multiplication."""
888
889 out = {}
890 len_mult = len(self)
891
892 for i, fact in enumerate(self):
893
894 for j in range(len(fact.spin_ind)):
895
896 for k in range(i+1, len_mult):
897 fact2 = self[k]
898 try:
899 l = fact2.spin_ind.index(fact.spin_ind[j])
900 except Exception:
901 pass
902 else:
903 out[(i, j)] = (k, l)
904 out[(k, l)] = (i, j)
905
906 return out
907
909 """return one variable which are contracted with var and not yet expanded"""
910
911 for var in self.unused:
912 obj = KERNEL.objs[var]
913 if obj.has_component(home.lorentz_ind, home.spin_ind):
914 return obj
915 return None
916
917
918
919
921 """ expand each part of the product and combine them.
922 Try to use a smart order in order to minimize the number of uncontracted indices.
923 Veto forbids the use of sub-expression if it contains some of the variable in the
924 expression. Veto contains the id of the vetoed variables
925 """
926
927 self.unused = self[:]
928
929 basic_end_point = [var for var in self if KERNEL.objs[var].contract_first]
930 product_term = []
931 current = None
932
933 while self.unused:
934
935 if not current:
936
937 try:
938
939 current = basic_end_point.pop()
940 except Exception:
941
942 current = self.unused.pop()
943 else:
944
945 if current not in self.unused:
946 current = None
947 continue
948
949 self.unused.remove(current)
950 cur_obj = KERNEL.objs[current]
951
952 product_term.append(cur_obj.expand())
953
954
955 var_obj = self.neighboor(product_term[-1])
956
957
958 if var_obj:
959 product_term[-1] *= var_obj.expand()
960 cur_obj = var_obj
961 self.unused.remove(cur_obj.id)
962 continue
963
964 current = None
965
966
967
968
969 out = self.prefactor
970 for fact in product_term[:]:
971 if hasattr(fact, 'vartype') and fact.lorentz_ind == fact.spin_ind == []:
972 scalar = fact.get_rep([0])
973 if hasattr(scalar, 'vartype') and scalar.vartype == 1:
974 if not veto or not scalar.contains(veto):
975 scalar = scalar.simplify()
976 prefactor = 1
977
978 if hasattr(scalar, 'vartype') and scalar.prefactor not in [1,-1]:
979 prefactor = scalar.prefactor
980 scalar.prefactor = 1
981 new = KERNEL.add_expression_contraction(scalar)
982 fact.set_rep([0], prefactor * new)
983 out *= fact
984 return out
985
987 """ create a shadow copy """
988 new = MultLorentz(self)
989 new.prefactor = self.prefactor
990 return new
991
996 """ A symbolic Object for All Helas object. All Helas Object Should
997 derivated from this class"""
998
999 contract_first = 0
1000 mult_class = MultLorentz
1001 add_class = AddVariable
1002
def __init__(self, name, lor_ind, spin_ind, tags=()):
    """Initialise a symbolic Lorentz object.

    :param name: name of the object.
    :param lor_ind: list of Lorentz index labels.
    :param spin_ind: list of spin index labels.
    :param tags: iterable of tags, passed as a set to ``KERNEL.add_tag``.
        The default is an immutable empty tuple rather than a shared mutable
        list.
    """
    assert isinstance(lor_ind, list)
    assert isinstance(spin_ind, list)

    self.name = name
    self.lorentz_ind = lor_ind
    self.spin_ind = spin_ind
    KERNEL.add_tag(set(tags))
1012
1014 """Expand the content information into LorentzObjectRepresentation."""
1015
1016 try:
1017 return self.representation
1018 except Exception:
1019 self.create_representation()
1020 return self.representation
1021
1023 raise self.VariableError("This Object %s doesn't have define representation" % self.__class__.__name__)
1024
1026 """check if this Lorentz Object have some of those indices"""
1027
1028 if any([id in self.lorentz_ind for id in lor_list]) or \
1029 any([id in self.spin_ind for id in spin_list]):
1030 return True
1031
1032
1033
1035 return '%s' % self.name
1036
1038 """ A symbolic Object for All Helas object. All Helas Object Should
1039 derivated from this class"""
1040
1041 mult_class = MultLorentz
1042 object_class = LorentzObject
1043
1047
1048 @classmethod
1050 """default way to have a unique name"""
1051 return '_L_%(class)s_%(args)s' % \
1052 {'class':cls.__name__,
1053 'args': '_'.join(args)
1054 }
1055
1061 """A concrete representation of the LorentzObject."""
1062
1063 vartype = 4
1064
1066 """Specify error for LorentzObjectRepresentation"""
1067
def __init__(self, representation, lorentz_indices, spin_indices):
    """Initialise the concrete representation of a Lorentz object.

    :param representation: for a tensor (at least one Lorentz or spin index),
        a dict mapping index tuples to component values. For a scalar (no
        index), one of the following:
        - an empty dict;
        - a dict holding only the ``(0,)`` key;
        - the bare scalar value itself.
    :param lorentz_indices: list of Lorentz index labels.
    :param spin_indices: list of spin index labels.
    :raises LorentzObjectRepresentationError: if a scalar is given a dict
        whose keys are not exactly ``(0,)``.
    """
    self.lorentz_ind = lorentz_indices
    self.nb_lor = len(lorentz_indices)
    self.spin_ind = spin_indices
    self.nb_spin = len(spin_indices)
    self.nb_ind = self.nb_lor + self.nb_spin

    if self.lorentz_ind or self.spin_ind:
        dict.__init__(self, representation)
    elif isinstance(representation, dict):
        # Scalar stored as a dict: only the (0,) component is allowed.
        if len(representation) == 0:
            self[(0,)] = 0
        elif len(representation) == 1 and (0,) in representation:
            self[(0,)] = representation[(0,)]
        else:
            raise self.LorentzObjectRepresentationError("There is no key of (0,) in representation.")
    else:
        # Scalar given directly as a (non-dict) value: store it as the
        # single (0,) component.
        self[(0,)] = representation
1099
1101 """ string representation """
1102 text = 'lorentz index :' + str(self.lorentz_ind) + '\n'
1103 text += 'spin index :' + str(self.spin_ind) + '\n'
1104
1105 for ind in self.listindices():
1106 ind = tuple(ind)
1107 text += str(ind) + ' --> '
1108 text += str(self.get_rep(ind)) + '\n'
1109 return text
1110
1112 """return the value/Variable associate to the indices"""
1113 return self[tuple(indices)]
1114
def set_rep(self, indices, value):
    """Store ``value`` as the component addressed by ``indices``.

    ``indices`` may be any iterable; it is converted to a tuple key.
    """
    key = tuple(indices)
    self[key] = value
1119
1121 """Return an iterator in order to be able to loop easily on all the
1122 indices of the object."""
1123 return IndicesIterator(self.nb_ind)
1124
1125 @staticmethod
1127 shift = len(switch_order)
1128 for value in l1:
1129 try:
1130 index = l2.index(value)
1131 except Exception:
1132 raise LorentzObjectRepresentation.LorentzObjectRepresentationError(
1133 "Invalid addition. Object doen't have the same lorentz "+ \
1134 "indices : %s != %s" % (l1, l2))
1135 else:
1136 switch_order.append(shift + index)
1137 return switch_order
1138
1139
1141
1142 if not obj:
1143 return self
1144
1145 if not hasattr(obj, 'vartype'):
1146 assert self.lorentz_ind == []
1147 assert self.spin_ind == []
1148 new = self[(0,)] + obj * fact
1149 out = LorentzObjectRepresentation(new, [], [])
1150 return out
1151
1152 assert(obj.vartype == 4 == self.vartype)
1153
1154 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:
1155
1156 switch_order = []
1157 self.get_mapping(self.lorentz_ind, obj.lorentz_ind, switch_order)
1158 self.get_mapping(self.spin_ind, obj.spin_ind, switch_order)
1159 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))])
1160 else:
1161
1162 switch = lambda ind : (ind)
1163
1164
1165 assert tuple(self.lorentz_ind+self.spin_ind) == tuple(switch(obj.lorentz_ind+obj.spin_ind)), '%s!=%s' % (self.lorentz_ind+self.spin_ind, switch(obj.lorentz_ind+self.spin_ind))
1166 assert tuple(self.lorentz_ind) == tuple(switch(obj.lorentz_ind)), '%s!=%s' % (tuple(self.lorentz_ind), switch(obj.lorentz_ind))
1167
1168
1169 new = LorentzObjectRepresentation({}, obj.lorentz_ind, obj.spin_ind)
1170
1171
1172 if fact == 1:
1173 for ind in self.listindices():
1174 value = obj.get_rep(ind) + self.get_rep(switch(ind))
1175 new.set_rep(ind, value)
1176 else:
1177 for ind in self.listindices():
1178 value = fact * obj.get_rep(switch(ind)) + self.get_rep(ind)
1179 new.set_rep(ind, value)
1180
1181 return new
1182
1184
1185 if not obj:
1186 return self
1187
1188 assert(obj.vartype == 4 == self.vartype)
1189
1190 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind:
1191
1192
1193 switch_order = []
1194 self.get_mapping(obj.lorentz_ind, self.lorentz_ind, switch_order)
1195 self.get_mapping(obj.spin_ind, self.spin_ind, switch_order)
1196 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))])
1197 else:
1198
1199 switch = lambda ind : (ind)
1200
1201
1202 assert tuple(switch(self.lorentz_ind+self.spin_ind)) == tuple(obj.lorentz_ind+obj.spin_ind), '%s!=%s' % (switch(self.lorentz_ind+self.spin_ind), (obj.lorentz_ind+obj.spin_ind))
1203 assert tuple(switch(self.lorentz_ind) )== tuple(obj.lorentz_ind), '%s!=%s' % (switch(self.lorentz_ind), tuple(obj.lorentz_ind))
1204
1205
1206 if fact == 1:
1207 for ind in self.listindices():
1208 self[tuple(ind)] += obj.get_rep(switch(ind))
1209 else:
1210 for ind in self.listindices():
1211 self[tuple(ind)] += fact * obj.get_rep(switch(ind))
1212 return self
1213
1215 return self.__add__(obj, fact= -1)
1216
1218 return obj.__add__(self, fact= -1)
1219
1221 return self.__add__(obj, fact= -1)
1222
1224 self *= -1
1225 return self
1226
1228 """multiplication performing directly the einstein/spin sommation.
1229 """
1230
1231 if not hasattr(obj, 'vartype'):
1232 out = LorentzObjectRepresentation({}, self.lorentz_ind, self.spin_ind)
1233 for ind in out.listindices():
1234 out.set_rep(ind, obj * self.get_rep(ind))
1235 return out
1236
1237
1238 assert(obj.__class__ == LorentzObjectRepresentation), \
1239 '%s is not valid class for this operation' %type(obj)
1240
1241
1242
1243 l_ind, sum_l_ind = self.compare_indices(self.lorentz_ind, \
1244 obj.lorentz_ind)
1245 s_ind, sum_s_ind = self.compare_indices(self.spin_ind, \
1246 obj.spin_ind)
1247 if not(sum_l_ind or sum_s_ind):
1248
1249 return self.tensor_product(obj)
1250
1251
1252
1253 new_object = LorentzObjectRepresentation({}, l_ind, s_ind)
1254
1255 for indices in new_object.listindices():
1256
1257 dict_l_ind = self.pass_ind_in_dict(indices[:len(l_ind)], l_ind)
1258 dict_s_ind = self.pass_ind_in_dict(indices[len(l_ind):], s_ind)
1259
1260 new_object.set_rep(indices, \
1261 self.contraction(obj, sum_l_ind, sum_s_ind, \
1262 dict_l_ind, dict_s_ind))
1263
1264 return new_object
1265
1266 __rmul__ = __mul__
1267 __imul__ = __mul__
1268
1269 - def contraction(self, obj, l_sum, s_sum, l_dict, s_dict):
1270 """ make the Lorentz/spin contraction of object self and obj.
1271 l_sum/s_sum are the position of the sum indices
1272 l_dict/s_dict are dict given the value of the fix indices (indices->value)
1273 """
# Sum, over every value of the summed Lorentz (l_sum) and spin (s_sum)
# indices, the product of the matching components of obj and self.
# IndicesIterator yields each index value in 0..3.
# Side effect: l_dict and s_dict are updated in place with the current values
# of the summed indices, so the caller's dicts are modified.
1274 out = 0
1275 len_l = len(l_sum)
1276 len_s = len(s_sum)
1277
1278
1279
1280 for l_value in IndicesIterator(len_l):
1281 l_dict.update(self.pass_ind_in_dict(l_value, l_sum))
1282 for s_value in IndicesIterator(len_s):
1283
1284 s_dict.update(self.pass_ind_in_dict(s_value, s_sum))
1285
1286
1287 self_ind = self.combine_indices(l_dict, s_dict)
1288 obj_ind = obj.combine_indices(l_dict, s_dict)
1289
1290
1291 factor = obj.get_rep(obj_ind) * self.get_rep(self_ind)
1292
1293 if factor:
1294
# Each summed Lorentz index with a non-zero value contributes a factor -1
# (presumably the (+,-,-,-) metric; confirm). An object with a prefactor has
# it scaled in place. Anything else (e.g. a plain number) is multiplied
# directly.
1295 try:
1296 factor.prefactor *= (-1) ** (len(l_value) - l_value.count(0))
1297 except Exception:
1298 factor *= (-1) ** (len(l_value) - l_value.count(0))
# NOTE(review): indentation is missing from this listing. This accumulation is
# presumably inside `if factor:`, so vanishing products are skipped. Confirm
# against the original source.
1299 out += factor
1300 return out
1301
1303 """ return the tensorial product of the object"""
1304 assert(obj.vartype == 4)
1305
1306 new_object = LorentzObjectRepresentation({}, \
1307 self.lorentz_ind + obj.lorentz_ind, \
1308 self.spin_ind + obj.spin_ind)
1309
1310
1311 lor1 = self.nb_lor
1312 lor2 = obj.nb_lor
1313 spin1 = self.nb_spin
1314 spin2 = obj.nb_spin
1315
1316
1317 if lor1 == 0 == spin1:
1318
1319 selfind = lambda indices: [0]
1320 else:
1321 selfind = lambda indices: indices[:lor1] + \
1322 indices[lor1 + lor2: lor1 + lor2 + spin1]
1323
1324
1325 if lor2 == 0 == spin2:
1326
1327 objind = lambda indices: [0]
1328 else:
1329 objind = lambda indices: indices[lor1: lor1 + lor2] + \
1330 indices[lor1 + lor2 + spin1:]
1331
1332
1333 for indices in new_object.listindices():
1334
1335 fac1 = self.get_rep(tuple(selfind(indices)))
1336 fac2 = obj.get_rep(tuple(objind(indices)))
1337 new_object.set_rep(indices, fac1 * fac2)
1338
1339 return new_object
1340
1342 """Try to factorize each component"""
1343 for ind, fact in self.items():
1344 if fact:
1345 self.set_rep(ind, fact.factorize())
1346
1347
1348 return self
1349
1351 """Check if we can simplify the object (check for non treated Sum)"""
1352
1353
1354 for ind, term in self.items():
1355 if hasattr(term, 'vartype'):
1356 self[ind] = term.simplify()
1357
1358 return self
1359
1360 @staticmethod
1362 """return two list, the first one contains the position of non summed
1363 index and the second one the position of summed index."""
1364
1365
1366
1367
1368
1369
1370 are_unique, are_sum = [], []
1371
1372
1373 for indice in list1:
1374 if indice in list2:
1375 are_sum.append(indice)
1376 else:
1377 are_unique.append(indice)
1378
1379
1380 for indice in list2:
1381 if indice not in are_sum:
1382 are_unique.append(indice)
1383
1384
1385 return are_unique, are_sum
1386
1387 @staticmethod
1389 """made a dictionary (pos -> index_value) for how call the object"""
1390 if not key:
1391 return {}
1392 out = {}
1393 for i, ind in enumerate(indices):
1394 out[key[i]] = ind
1395 return out
1396
1398 """return the indices in the correct order following the dicts rules"""
1399
1400 out = []
1401
1402 for value in self.lorentz_ind:
1403 out.append(l_dict[value])
1404
1405 for value in self.spin_ind:
1406 out.append(s_dict[value])
1407
1408 return out
1409
def split(self, variables_id):
    """Split every component according to the powers of the given variables.

    :param variables_id: ids of the variables to factor out.
    :return: a ``SplitCoefficient`` mapping each power-tuple to a
        ``LorentzObjectRepresentation`` with the same indices as ``self``.
        Components in which a power-tuple does not appear are 0.

    Numerical components contribute only to the all-zero power-tuple.
    Symbolic components are split with their own ``split`` method, which may
    modify them in place.
    """
    out = SplitCoefficient()
    zero_rep = {}
    for indices in self.listindices():
        zero_rep[tuple(indices)] = 0
    zero_key = tuple([0] * len(variables_id))

    for indices in self.listindices():
        component = self.get_rep(indices)
        if isinstance(component, numbers.Number):
            contributions = [(zero_key, component)]
        else:
            contributions = component.split(variables_id).items()
        for key, value in contributions:
            if key not in out:
                out[key] = LorentzObjectRepresentation(dict(zero_rep),
                                                       self.lorentz_ind,
                                                       self.spin_ind)
            out[key][tuple(indices)] += value
    return out
1441
1449 """Class needed for the iterator"""
1450
1452 """ create an iterator looping over the indices of a list of len "len"
1453 with each value can take value between 0 and 3 """
1454
1455 self.len = len
1456 if len:
1457
1458
1459 self.data = [-1] + [0] * (len - 1)
1460 else:
1461
1462 self.data = 0
1463
1466
1468 if not self.len:
1469 return self.nextscalar()
1470
1471 for i in range(self.len):
1472 if self.data[i] < 3:
1473 self.data[i] += 1
1474 return self.data
1475 else:
1476 self.data[i] = 0
1477 raise StopIteration
1478
1479 next = __next__
1480
1482 if self.data:
1483 raise StopIteration
1484 else:
1485 self.data = True
1486 return [0]
1487
1489
1493
1495 """return the highest rank of the coefficient"""
1496
1497 return max([max(arg[:4]) for arg in self])
1498
1499
1500 if '__main__' ==__name__:
1501
1502 import cProfile
1506
1507 cProfile.run('create()')
1508