Package aloha :: Module aloha_lib
[hide private]
[frames] | no frames]

Source Code for Module aloha.aloha_lib

   1  ################################################################################ 
   2  # 
   3  # Copyright (c) 2010 The MadGraph5_aMC@NLO Development team and Contributors 
   4  # 
   5  # This file is a part of the MadGraph5_aMC@NLO project, an application which  
   6  # automatically generates Feynman diagrams and matrix elements for arbitrary 
   7  # high-energy processes in the Standard Model and beyond. 
   8  # 
   9  # It is subject to the MadGraph5_aMC@NLO license which should accompany this  
  10  # distribution. 
  11  # 
  12  # For more information, visit madgraph.phys.ucl.ac.be and amcatnlo.web.cern.ch 
  13  # 
  14  ################################################################################ 
  15  ##   Diagram of Class 
  16  ## 
  17  ##    Variable (vartype:0)<--- ScalarVariable  
  18  ##                          | 
  19  ##                          +- LorentzObject  
  20  ##                                 
  21  ## 
  22  ##    list <--- AddVariable (vartype :1)    
  23  ##            
  24  ##    array <--- MultVariable  <--- MultLorentz (vartype:2)  
  25  ##            
  26  ##    list <--- LorentzObjectRepresentation (vartype :4) <-- ConstantObject 
  27  ##                                                               (vartype:5) 
  28  ## 
  29  ##    FracVariable (vartype:3) 
  30  ## 
  31  ##    MultContainer (vartype:6) 
  32  ## 
  33  ################################################################################ 
  34  ## 
  35  ##   Variable is in fact a factory which adds a reference to the variable name 
  36  ##   into the KERNEL (of class Computation), instantiates a real variable object 
  37  ##   (of class C_Variable or R_Variable for complex/real) and returns a MultVariable 
  38  ##   with a single element. 
  39  ## 
  40  ##   Lorentz Object works in the same way. 
  41  ## 
  42  ################################################################################ 
  43   
  44   
  45  from __future__ import division 
  46  from __future__ import absolute_import 
  47  from __future__ import print_function 
  48  from array import array 
  49  import collections 
  50  from fractions import Fraction 
  51  import numbers 
  52  import re 
  53  import aloha # define mode of writting 
  54  from six.moves import range 
  55   
  56  try: 
  57      import madgraph.various.misc as misc 
  58  except Exception: 
  59      import aloha.misc as misc 
class defaultdict(collections.defaultdict):
    """``collections.defaultdict`` whose instances are themselves callable.

    Calling an instance ignores every argument and hands back a brand new
    ``defaultdict(int)``.  This lets an instance be used as the
    ``default_factory`` of another one, e.g. ``defaultdict(defaultdict(int))``.
    """

    def __call__(self, *args):
        fresh_counter = defaultdict(int)
        return fresh_counter
65
class Computation(dict):
    """Encapsulate all the state of a computation, to limit side effects.

    The mapping itself associates each registered name with an integer id;
    ``objs[id]`` is the object that was registered under that id.
    """

    # Function tags that are never recorded in ``unknow_fct`` (tags starting
    # with 'cmath.' are not recorded either).
    known_fct = ['/', 'log', 'pow', 'sin', 'cos', 'asin', 'acos', 'tan', 'cot', 'acot',
                 'theta_function', 'exp']

    def __init__(self):
        self.objs = []            # registered objects, indexed by id
        self.use_tag = set()      # tags collected through add_tag
        self.id = -1              # last id handed out
        self.reduced_expr = {}    # str(expression) -> [variable, tag]
        self.fct_expr = {}        # 'FCT<n>' -> (fct_tag, arguments)
        self.reduced_expr2 = {}   # tag -> factorized expression / function call
        self.inverted_fct = {}    # str(fct_tag)+str(arguments) -> 'FCT(<n>)'
        self.has_pi = False       # logical to check if pi is used in at least one fct
        self.unknow_fct = []      # (fct_tag, nb of arguments) of unrecognised functions
        dict.__init__(self)

    def clean(self):
        """Reset every attribute and forget all registered names."""
        self.__init__()
        self.clear()

    def add(self, name, obj):
        """Register ``obj`` under ``name`` and return the id assigned to it."""
        self.id += 1
        self.objs.append(obj)
        self[name] = self.id
        return self.id

    def get(self, name):
        """Return the object registered under ``name`` (KeyError if unknown).

        Note that this replaces ``dict.get``: there is no default value.
        """
        return self.objs[self[name]]

    def add_tag(self, tag):
        """Add every element of the iterable ``tag`` to ``use_tag``."""
        self.use_tag.update(tag)

    def get_ids(self, variables):
        """Return the list of identification numbers associated to the given
        variable names. If a variable does not exist, create it (in complex).
        """
        out = []
        for var in variables:
            try:
                var_id = self[var]
            except KeyError:
                assert var not in ['M', 'W']
                var_id = Variable(var).get_id()
            out.append(var_id)
        return out

    def add_expression_contraction(self, expression):
        """Return the variable standing for ``expression``.

        The expression is looked up by its string form. If it is already
        known, the cached variable is returned. Otherwise it is simplified,
        a new variable ``TMP<n>`` is created for it and its factorized form
        is stored in ``reduced_expr2``. Returns 0 when the expression is (or
        simplifies to) zero.
        """
        str_expr = str(expression)
        if str_expr in self.reduced_expr:
            out, tag = self.reduced_expr[str_expr]
            self.add_tag((tag,))
            return out
        if expression == 0:
            return 0
        new_2 = expression.simplify()
        if new_2 == 0:
            return 0
        # Add a new variable
        tag = 'TMP%s' % len(self.reduced_expr)
        new = Variable(tag)
        self.reduced_expr[str_expr] = [new, tag]
        new_2 = new_2.factorize()
        self.reduced_expr2[tag] = new_2
        self.add_tag((tag,))
        return new

    def add_function_expression(self, fct_tag, *args):
        """Register the call ``fct_tag(*args)`` and return its string tag.

        Symbolic arguments (MultLorentz, AddVariable, LorentzObject) are
        expanded to their scalar component, simplified and factorized.
        If the function is a cmath/known function and every argument is a
        number, the call is evaluated and ``str(result)`` is returned.
        Otherwise the string ``'FCT(<n>)'`` identifying the call is returned.

        Raises aloha.ALOHAERROR if a symbolic argument is not a scalar.
        """
        if not (fct_tag.startswith('cmath.') or fct_tag in self.known_fct or
                (fct_tag, len(args)) in self.unknow_fct):
            self.unknow_fct.append((fct_tag, len(args)))

        argument = []
        for expression in args:
            if isinstance(expression, (MultLorentz, AddVariable, LorentzObject)):
                try:
                    expr = expression.expand().get_rep([0])
                except KeyError as error:
                    if error.args != ((0,),):
                        raise
                    else:
                        raise aloha.ALOHAERROR('''Error in input format.
    Argument of function (or denominator) should be scalar.
    We found %s''' % expression)
                new = expr.simplify()
                if not isinstance(new, numbers.Number):
                    new = new.factorize()
                argument.append(new)
            else:
                argument.append(expression)

        # flag as used every function already referenced inside the arguments
        for arg in argument:
            for v in re.findall(r'''\bFCT(\d*)\b''', str(arg)):
                self.add_tag(('FCT%s' % v,))

        # check if the function is a pure numerical function.
        if (fct_tag.startswith('cmath.') or fct_tag in self.known_fct) and \
                all(isinstance(x, numbers.Number) for x in argument):
            import cmath  # looked up by name inside the evaluated string below
            if fct_tag.startswith('cmath.'):
                module = ''
            else:
                module = 'cmath.'
            call_text = "%s%s(%s)" % (module, fct_tag,
                                      ','.join(repr(x) for x in argument))
            # NOTE(review): eval of a string built from fct_tag; this is only
            # safe as long as fct_tag comes from a trusted model definition.
            try:
                return str(eval(call_text))
            except Exception as error:
                # Not evaluable (e.g. '/', 'cot'): report exactly what was
                # attempted and fall back to the symbolic treatment below.
                print(error)
                print(call_text)

        fct_key = str(fct_tag) + str(argument)
        if fct_key in self.inverted_fct:
            tag = self.inverted_fct[fct_key]
            v = tag.split('(')[1][:-1]
            self.add_tag(('FCT%s' % v,))
            return tag
        else:
            fct_id = len(self.fct_expr)
            tag = 'FCT%s' % fct_id
            self.inverted_fct[fct_key] = 'FCT(%s)' % fct_id
            self.fct_expr[tag] = (fct_tag, argument)
            self.reduced_expr2[tag] = (fct_tag, argument)
            self.add_tag((tag,))

            return 'FCT(%s)' % fct_id
195 196 KERNEL = Computation()
#===============================================================================
# AddVariable
#===============================================================================
class AddVariable(list):
    """A list of Variable/ConstantObject/... representing the sum of those
    objects, multiplied by the overall factor ``prefactor``."""

    # variable to speed up class recognition
    vartype = 1

    def __init__(self, old_data=[], prefactor=1):
        """Initialise the sum with the terms of ``old_data`` (copied into the
        list, the default is never mutated) and the overall ``prefactor``."""
        self.prefactor = prefactor
        list.__init__(self, old_data)

    def simplify(self):
        """Apply the simplification rules (in place) and return the result.

        Numeric terms are merged into one constant, terms sharing the same
        sorted content are merged by adding their prefactors, and a common
        prefactor is pulled out. Returns a number, the single remaining term
        or ``self``.
        """
        # deal with one length object (a lone number is handled by the
        # generic code below, since it has no simplify method)
        if len(self) == 1 and hasattr(self[0], 'vartype'):
            return self.prefactor * self[0].simplify()
        constant = 0
        items = {}
        pos = -1
        for term in self[:]:
            pos += 1  # current position in the real self
            if not hasattr(term, 'vartype'):
                if isinstance(term, dict):
                    # allow term of type {(0,): x}
                    assert list(term.keys()) == [(0,)]
                    term = term[(0,)]
                constant += term
                del self[pos]
                pos -= 1
                continue
            tag = tuple(term.sort())
            if tag in items:
                orig_prefac = items[tag].prefactor
                items[tag].prefactor += term.prefactor
                # treat a near-perfect cancellation (0.33333 - 0.3333) as zero
                if items[tag].prefactor and \
                        abs(items[tag].prefactor) / (abs(orig_prefac) + abs(term.prefactor)) < 1e-8:
                    items[tag].prefactor = 0
                del self[pos]
                pos -= 1
            else:
                items[tag] = term.__class__(term, term.prefactor)
                self[pos] = items[tag]

        # get the optimized prefactor
        countprefact = defaultdict(int)
        nbplus, nbminus = 0, 0
        if constant not in [0, 1, -1]:
            countprefact[constant] += 1
            if constant.real + constant.imag > 0:
                nbplus += 1
            else:
                nbminus += 1

        for var in items.values():
            if var.prefactor == 0:
                self.remove(var)
            else:
                nb = var.prefactor
                if nb in [1, -1]:
                    continue
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1
        if countprefact and max(countprefact.values()) > 1:
            # most frequent prefactor (first one met in case of a tie)
            fact_prefactor = max(countprefact.items(), key=lambda x: x[1])[0]
        else:
            fact_prefactor = 1
        if nbplus < nbminus:
            fact_prefactor *= -1
        self.prefactor *= fact_prefactor

        if fact_prefactor != 1:
            for i, a in enumerate(self):
                try:
                    a.prefactor /= fact_prefactor
                except AttributeError:
                    self[i] /= fact_prefactor

        if constant:
            self.append(constant / fact_prefactor)

        # deal with one/zero length object
        varlen = len(self)
        if varlen == 1:
            if hasattr(self[0], 'vartype'):
                return self.prefactor * self[0].simplify()
            else:
                # self[0] is a number
                return self.prefactor * self[0]
        elif varlen == 0:
            return 0
        return self

    def split(self, variables_id):
        """Return a dict whose keys are the powers associated to each variable
        and whose values are what remains once those variables are removed."""
        out = defaultdict(int)
        for obj in self:
            for key, value in obj.split(variables_id).items():
                out[key] += self.prefactor * value
        return out

    def contains(self, variables):
        """Return True if one of the variables is in one of the terms."""
        return any((v in obj for obj in self for v in variables))

    def get_all_var_names(self):
        """Concatenate get_all_var_names() of every term providing it."""
        out = []
        for term in self:
            if hasattr(term, 'get_all_var_names'):
                out += term.get_all_var_names()
        return out

    def replace(self, id, expression):
        """Replace one object (identified by its id) by a given expression.
        Note that expression can't be zero.
        Note that this should be in canonical form (this should contain ONLY
        MultVariable), so this should be called before a factorize.
        """
        new = self.__class__()

        for obj in self:
            assert isinstance(obj, MultVariable)
            tmp = obj.replace(id, expression)
            new += tmp
        new.prefactor = self.prefactor
        return new

    def expand(self, veto=[]):
        """Pass from high level object to low level object: expand each term
        (forwarding ``veto``, which is not modified here) and sum them."""
        if not self:
            return self
        if self.prefactor == 1:
            new = self[0].expand(veto)
        else:
            new = self.prefactor * self[0].expand(veto)

        for item in self[1:]:
            if self.prefactor == 1:
                try:
                    new += item.expand(veto)
                except AttributeError:
                    # item has no expand method: add it as it is
                    new = new + item
            else:
                new += (self.prefactor) * item.expand(veto)
        return new

    def __mul__(self, obj):
        """Define the multiplication of
            - an AddVariable with a number
            - an AddVariable with an AddVariable
        Other multiplications are defined via the symmetric operation of the
        class of obj."""
        if not hasattr(obj, 'vartype'):  # obj is a number
            if not obj:
                return 0
            return self.__class__(self, self.prefactor * obj)
        elif obj.vartype == 1:  # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i * j for i in self for j in obj]
            return new
        else:
            # force the program to look at obj * self
            return NotImplemented

    def __imul__(self, obj):
        """In-place variant of __mul__: a number only rescales the prefactor;
        an AddVariable still produces a new object."""
        if not hasattr(obj, 'vartype'):  # obj is a number
            if not obj:
                return 0
            self.prefactor *= obj
            return self
        elif obj.vartype == 1:  # obj is an AddVariable
            new = self.__class__([], self.prefactor * obj.prefactor)
            new[:] = [i * j for i in self for j in obj]
            return new
        else:
            # force the program to look at obj * self
            return NotImplemented

    def __neg__(self):
        """Return the opposite sum, leaving the operand untouched (same
        shallow copy as a multiplication by -1)."""
        return self.__class__(self, -self.prefactor)

    def __add__(self, obj):
        """Define all the different additions. New terms are divided by the
        prefactor of self since it multiplies the whole list."""
        if not hasattr(obj, 'vartype'):
            if not obj:  # obj is zero
                return self
            new = self.__class__(self, self.prefactor)
            new.append(obj / self.prefactor)
            return new
        elif obj.vartype == 2:  # obj is a MultVariable
            new = AddVariable(self, self.prefactor)
            if self.prefactor == 1:
                new.append(obj)
            else:
                new.append((1 / self.prefactor) * obj)
            return new
        elif obj.vartype == 1:  # obj is an AddVariable
            new = AddVariable(self, self.prefactor)
            for item in obj:
                new.append(obj.prefactor / self.prefactor * item)
            return new
        else:
            # force to look at obj + self
            return NotImplemented

    def __iadd__(self, obj):
        """In-place variant of __add__."""
        if not hasattr(obj, 'vartype'):
            if not obj:  # obj is zero
                return self
            self.append(obj / self.prefactor)
            return self
        elif obj.vartype == 2:  # obj is a MultVariable
            if self.prefactor == 1:
                self.append(obj)
            else:
                self.append((1 / self.prefactor) * obj)
            return self
        elif obj.vartype == 1:  # obj is an AddVariable
            for item in obj:
                self.append(obj.prefactor / self.prefactor * item)
            return self
        else:
            # force to look at obj + self
            return NotImplemented

    def __sub__(self, obj):
        return self + (-1) * obj

    def __rsub__(self, obj):
        return (-1) * self + obj

    __radd__ = __add__
    __rmul__ = __mul__

    def __div__(self, obj):
        """Division by a number."""
        return self.__mul__(1 / obj)

    __truediv__ = __div__

    def __rdiv__(self, obj):
        """``obj / self`` cannot be written as a sum: let Python report the
        unsupported operation (this used to call a nonexistent method)."""
        return NotImplemented

    def __str__(self):
        text = ''
        if self.prefactor != 1:
            text += str(self.prefactor) + ' * '
        text += '( '
        text += ' + '.join([str(item) for item in self])
        text += ' )'
        return text

    def __repr__(self):
        text = ''
        if self.prefactor != 1:
            text += str(self.prefactor) + ' * '
        text += super(AddVariable, self).__repr__()
        return text

    def count_term(self):
        """Count the number of appearances of each variable and return
        ``(count, variable)`` for the best candidate to factorize.

        Returns ``(1, None)`` when no variable appears in more than one term.
        Ties are broken with the correlation weight, then with the lowest
        string representation.
        """
        count = defaultdict(int)
        correlation = defaultdict(defaultdict(int))
        for term in self:
            try:
                set_term = set(term)
            except TypeError:
                # constant term
                continue
            for val1 in set_term:
                count[val1] += 1
                # allow to find optimized factorization for identical count
                for val2 in set_term:
                    correlation[val1][val2] += 1

        maxnb = max(count.values()) if count else 0
        possibility = [v for v, val in count.items() if val == maxnb]
        if maxnb == 1:
            return 1, None
        elif len(possibility) == 1:
            return maxnb, possibility[0]

        max_wgt, maxvar = 0, None
        for var in possibility:
            wgt = sum(w**2 for w in correlation[var].values()) / len(correlation[var])
            if wgt > max_wgt:
                maxvar = var
                max_wgt = wgt
                str_maxvar = str(KERNEL.objs[var])
            elif wgt == max_wgt:
                # keep the one with the lowest string expr
                new_str = str(KERNEL.objs[var])
                if new_str < str_maxvar:
                    maxvar = var
                    str_maxvar = new_str
        return maxnb, maxvar

    def factorize(self):
        """Try to factorize the expression as much as possible.

        The result has the form MAXVAR * NEWADD + CONSTANT. Note that the
        terms of self are modified (MAXVAR is removed from them).
        """
        maxnb, maxvar = self.count_term()
        if maxnb <= 1:
            # no factorization possible
            return self

        # split in MAXVAR * NEWADD + CONSTANT
        newadd = AddVariable()
        constant = AddVariable()
        # fill NEWADD and CONSTANT
        for term in self:
            try:
                term.remove(maxvar)
            except (AttributeError, TypeError, ValueError):
                # number, or term which does not contain maxvar
                constant.append(term)
            else:
                if len(term):
                    newadd.append(term)
                else:
                    newadd.append(term.prefactor)
        newadd = newadd.factorize()

        # optimize the prefactor
        if isinstance(newadd, AddVariable):
            countprefact = defaultdict(int)
            nbplus, nbminus = 0, 0
            for nb in [a.prefactor for a in newadd if hasattr(a, 'prefactor')]:
                countprefact[abs(nb)] += 1
                if nb.real + nb.imag > 0:
                    nbplus += 1
                else:
                    nbminus += 1

            # countprefact is empty when newadd only holds plain numbers
            if countprefact:
                newadd.prefactor = max(countprefact.items(), key=lambda x: x[1])[0]
            if nbplus < nbminus:
                newadd.prefactor *= -1
            if newadd.prefactor != 1:
                for i, a in enumerate(newadd):
                    try:
                        a.prefactor /= newadd.prefactor
                    except AttributeError:
                        newadd[i] /= newadd.prefactor

        if len(constant) > 1:
            constant = constant.factorize()
        elif constant:
            constant = constant[0]
        else:
            out = MultContainer([KERNEL.objs[maxvar], newadd])
            out.prefactor = self.prefactor
            if newadd.prefactor != 1:
                out.prefactor *= newadd.prefactor
                newadd.prefactor = 1
            return out
        out = AddVariable([MultContainer([KERNEL.objs[maxvar], newadd]), constant],
                          self.prefactor)
        return out
591
class MultContainer(list):
    """List of factors multiplied together, with an overall ``prefactor``."""

    vartype = 6

    def __init__(self, *args):
        self.prefactor = 1
        list.__init__(self, *args)

    def __str__(self):
        """ String representation """
        if self.prefactor != 1:
            text = '(%s * %s)' % (self.prefactor, ' * '.join([str(t) for t in self]))
        else:
            text = '(%s)' % (' * '.join([str(t) for t in self]))
        return text

    def factorize(self):
        """Factorize every factor in place and return self.

        Returning self (instead of None) matches the other factorize methods
        of this file, whose callers use the returned value.
        """
        self[:] = [term.factorize() for term in self]
        return self
610
class MultVariable(array):
    """ A list of Variable with multiplication as operator between themselves.
    Represented by an array of ids for speed optimization.
    """
    vartype = 2
    addclass = AddVariable

    def __new__(cls, old=[], prefactor=1):
        # the content is stored here; ``old`` is only read
        return array.__new__(cls, 'i', old)

    def __init__(self, old=[], prefactor=1):
        """ initialization of the object with default value """
        # array.__init__(self, 'i', old) <- done already in new !!
        self.prefactor = prefactor
        assert isinstance(self.prefactor, (float, int, complex))

    def _with_prefactor(self, prefactor):
        """Return a copy of self carrying the given prefactor."""
        new = self.__class__(self)
        new.prefactor = prefactor
        return new

    def get_id(self):
        """Return the id of the single variable of this product."""
        assert len(self) == 1
        return self[0]

    def sort(self):
        """Sort the ids in place and return self."""
        self[:] = array('i', sorted(self))
        return self

    def simplify(self):
        """ simplify the product: an empty product is just its prefactor """
        if not len(self):
            return self.prefactor
        return self

    def split(self, variables_id):
        """Return a dict whose key is the power associated to each variable
        and whose value is self once all those variables have been removed
        (self is modified in place)."""
        key = tuple([self.count(i) for i in variables_id])
        arg = [id for id in self if id not in variables_id]
        self[:] = array('i', arg)
        return SplitCoefficient([(key, self)])

    def replace(self, id, expression):
        """Replace one object (identified by its id) by a given expression.
        Note that expression can't be zero.
        """
        assert hasattr(expression, 'vartype'), 'expression should be of type Add or Mult'

        if expression.vartype == 1:  # AddVariable
            nb = self.count(id)
            if not nb:
                return self
            for i in range(nb):
                self.remove(id)
            new = self
            for i in range(nb):
                new *= expression
            return new
        elif expression.vartype == 2:  # MultLorentz
            # be careful about A -> A * B
            nb = self.count(id)
            for i in range(nb):
                self.remove(id)
                self.__imul__(expression)
            return self
        else:
            raise Exception('Cann\'t replace a Variable by %s' % type(expression))

    def get_all_var_names(self):
        """return the list of variable used in this multiplication"""
        return ['%s' % KERNEL.objs[n] for n in self]

    # Defining rule of Multiplication
    def __mul__(self, obj):
        """Define the multiplication with different object"""
        if not hasattr(obj, 'vartype'):  # should be a number
            if obj:
                return self.__class__(self, obj * self.prefactor)
            else:
                return 0
        elif obj.vartype == 1:  # obj is an AddVariable
            new = obj.__class__([], self.prefactor * obj.prefactor)
            # the prefactor is carried by ``new``: neutralise it temporarily
            old, self.prefactor = self.prefactor, 1
            new[:] = [self * term for term in obj]
            self.prefactor = old
            return new
        elif obj.vartype == 4:
            return NotImplemented

        return self.__class__(array.__add__(self, obj), self.prefactor * obj.prefactor)

    __rmul__ = __mul__

    def __imul__(self, obj):
        """Define the in-place multiplication with different object"""
        if not hasattr(obj, 'vartype'):  # should be a number
            if obj:
                self.prefactor *= obj
                return self
            else:
                return 0
        elif obj.vartype == 1:  # obj is an AddVariable
            new = obj.__class__([], self.prefactor * obj.prefactor)
            self.prefactor = 1
            new[:] = [self * term for term in obj]
            return new
        elif obj.vartype == 4:
            return NotImplemented

        self.prefactor *= obj.prefactor
        return array.__iadd__(self, obj)

    def __pow__(self, value):
        out = 1
        for i in range(value):
            out *= self
        return out

    def __add__(self, obj):
        """ define the addition with different object"""
        if not obj:
            return self
        elif not hasattr(obj, 'vartype') or obj.vartype == 2:
            new = self.addclass([self, obj])
            return new
        else:
            # call the implementation of addition implemented in obj
            return NotImplemented

    __radd__ = __add__
    __iadd__ = __add__

    def __sub__(self, obj):
        return self + (-1) * obj

    def __neg__(self):
        """Return the opposite product, leaving the operand untouched."""
        return self._with_prefactor(-self.prefactor)

    def __rsub__(self, obj):
        return (-1) * self + obj

    def __idiv__(self, obj):
        """ ONLY NUMBER DIVISION ALLOWED (in place)"""
        assert not hasattr(obj, 'vartype')
        self.prefactor /= obj
        return self

    __itruediv__ = __idiv__

    def __div__(self, obj):
        """ ONLY NUMBER DIVISION ALLOWED. Returns a new object, so that
        ``a / 2`` does not modify ``a``."""
        assert not hasattr(obj, 'vartype')
        return self._with_prefactor(self.prefactor / obj)

    __truediv__ = __div__

    def __str__(self):
        """ String representation """
        t = ['%s' % KERNEL.objs[n] for n in self]
        if self.prefactor != 1:
            text = '(%s * %s)' % (self.prefactor, ' * '.join(t))
        else:
            text = '(%s)' % (' * '.join(t))
        return text

    __repr__ = __str__

    def factorize(self):
        return self
795
#===============================================================================
# FactoryVar
#===============================================================================
class C_Variable(str):
    """Variable name (a string) flagged as being of complex type."""

    vartype = 0
    type = 'complex'
803
class R_Variable(str):
    """Variable name (a string) flagged as being of double type."""

    vartype = 0
    type = 'double'
807
class ExtVariable(str):
    """Variable name (a string) flagged as being of parameter type."""

    vartype = 0
    type = 'parameter'
811
class FactoryVar(object):
    """Standard entry point for every variable linked to an expression.

    Instantiating the class does not give a FactoryVar: it registers the
    variable in KERNEL when needed and returns a one-element ``mult_class``.
    """
    mult_class = MultVariable  # The class for the multiplication

    def __new__(cls, name, baseclass, *args):
        """Return a ``mult_class`` holding the id associated to ``name``.

        An unknown name is first created as ``baseclass(name, *args)``, added
        to KERNEL, and given its id as ``id`` attribute.
        """
        if name in KERNEL:
            var_id = KERNEL[name]
        else:
            created = baseclass(name, *args)
            var_id = KERNEL.add(name, created)
            created.id = var_id
        return cls.mult_class([var_id])
828
class Variable(FactoryVar):
    """Factory for a named variable, of class C_Variable unless told otherwise."""

    def __new__(self, name, type=C_Variable):
        """Delegate to FactoryVar with ``type`` as the class of the variable."""
        product = FactoryVar(name, type)
        return product
833
class DVariable(FactoryVar):
    """Factory for a variable which is of class R_Variable by default."""

    def __new__(self, name):
        """Create the variable, as C_Variable in the following cases:
            - aloha.complex_mass is set and the name starts with 'M', 'W' or 'OM'
            - aloha.loop_mode is set and the name starts with 'P'
        and as R_Variable otherwise.
        """
        # some parameters are passed to complex
        as_complex = aloha.complex_mass and \
            (name[0] in ['M', 'W'] or name.startswith('OM'))
        if not as_complex:
            as_complex = aloha.loop_mode and name.startswith('P')

        if as_complex:
            return FactoryVar(name, C_Variable)
        # Normal case:
        return FactoryVar(name, R_Variable)
846
#===============================================================================
# Object for Analytical Representation of Lorentz object (not scalar one)
#===============================================================================


#===============================================================================
# MultLorentz
#===============================================================================
class MultLorentz(MultVariable):
    """Specific class for LorentzObject Multiplication"""

    add_class = AddVariable  # Define which class describe the addition

    def _find_contraction(self, attr):
        """Return the {(pos1, ind1): (pos2, ind2)} map of the indices shared
        by two factors, reading the index list ``attr`` of each factor."""
        # the array only stores ids: the objects are held by KERNEL
        factors = [KERNEL.objs[var] for var in self]
        out = {}
        # Loop over the elements
        for i, fact in enumerate(factors):
            # and over the indices of this element
            for j, index in enumerate(getattr(fact, attr)):
                # in order to compare with the other elements of the product
                for k in range(i + 1, len(factors)):
                    other = getattr(factors[k], attr)
                    if index in other:
                        l = other.index(index)
                        out[(i, j)] = (k, l)
                        out[(k, l)] = (i, j)
        return out

    def find_lorentzcontraction(self):
        """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
        the lorentz contraction in this Multiplication."""
        return self._find_contraction('lorentz_ind')

    def find_spincontraction(self):
        """return of (pos_object1, indice1) ->(pos_object2,indices2) defining
        the spin contraction in this Multiplication."""
        return self._find_contraction('spin_ind')

    def neighboor(self, home):
        """Return one not yet expanded object sharing an index with ``home``
        (None if there is no such object)."""
        for var in self.unused:
            obj = KERNEL.objs[var]
            if obj.has_component(home.lorentz_ind, home.spin_ind):
                return obj
        return None

    def expand(self, veto=[]):
        """ expand each part of the product and combine them.
        Try to use a smart order in order to minimize the number of uncontracted indices.
        Veto forbids the use of sub-expression if it contains some of the variable in the
        expression. Veto contains the id of the vetoed variables (it is only read).
        """
        self.unused = self[:]  # list of not expanded
        # made in a list the interesting starting point for the computation
        basic_end_point = [var for var in self if KERNEL.objs[var].contract_first]
        product_term = []  # store result of intermediate chains
        current = None     # current point in the working chain

        while self.unused:
            # Loop until we have expanded everything.
            # ``is None`` and not a truth test: 0 is a valid id.
            if current is None:
                # First we need to have a starting point
                try:
                    # look in priority in basic_end_point (P/S/fermion/...)
                    current = basic_end_point.pop()
                except IndexError:
                    # take one of the remaining
                    current = self.unused.pop()
                else:
                    # check that this one is not already used
                    if current not in self.unused:
                        current = None
                        continue
                    # remove it from unused (usually done in the pop)
                    self.unused.remove(current)
                cur_obj = KERNEL.objs[current]
                # initialize the new chain
                product_term.append(cur_obj.expand())

            # We have a point -> find the next one: a term which is contracted
            # with the current chain and which is not yet expanded.
            var_obj = self.neighboor(product_term[-1])
            if var_obj:
                product_term[-1] *= var_obj.expand()
                self.unused.remove(var_obj.id)
                continue

            current = None

        # Multiply all those chains
        # For Fermion/Vector only one can carry index.
        out = self.prefactor
        for fact in product_term[:]:
            if hasattr(fact, 'vartype') and fact.lorentz_ind == fact.spin_ind == []:
                scalar = fact.get_rep([0])
                if hasattr(scalar, 'vartype') and scalar.vartype == 1:
                    if not veto or not scalar.contains(veto):
                        # replace the scalar sum by a temporary variable
                        scalar = scalar.simplify()
                        prefactor = 1

                        if hasattr(scalar, 'vartype') and scalar.prefactor not in [1, -1]:
                            prefactor = scalar.prefactor
                            scalar.prefactor = 1
                        new = KERNEL.add_expression_contraction(scalar)
                        fact.set_rep([0], prefactor * new)
            out *= fact
        return out

    def __copy__(self):
        """ create a shallow copy """
        new = MultLorentz(self)
        new.prefactor = self.prefactor
        return new
991
#===============================================================================
# LorentzObject
#===============================================================================
class LorentzObject(object):
    """ A symbolic Object for All Helas object. All Helas Object should
    derive from this class"""

    contract_first = 0
    mult_class = MultLorentz  # The class for the multiplication
    add_class = AddVariable   # The class for the addition

    def __init__(self, name, lor_ind, spin_ind, tags=[]):
        """Store the name and the lists of lorentz/spin indices, and register
        ``tags`` in KERNEL (``tags`` is only read)."""
        assert isinstance(lor_ind, list)
        assert isinstance(spin_ind, list)

        self.name = name
        self.lorentz_ind = lor_ind
        self.spin_ind = spin_ind
        KERNEL.add_tag(set(tags))

    def expand(self):
        """Expand the content information into LorentzObjectRepresentation
        (created on first use through create_representation)."""
        try:
            return self.representation
        except AttributeError:
            self.create_representation()
        return self.representation

    def create_representation(self):
        """Set ``self.representation``; has to be provided by the subclasses.

        Raises the ``VariableError`` class of the object when it has one,
        NotImplementedError otherwise.
        """
        error_class = getattr(self, 'VariableError', NotImplementedError)
        raise error_class("This Object %s doesn't have define representation" % self.__class__.__name__)

    def has_component(self, lor_list, spin_list):
        """Return True if this Lorentz Object has one of those indices,
        False otherwise."""
        return any(id in self.lorentz_ind for id in lor_list) or \
            any(id in self.spin_ind for id in spin_list)

    def __str__(self):
        return '%s' % self.name
1036
class FactoryLorentz(FactoryVar):
    """Factory for the Helas objects: builds a unique name from the arguments
    and lets FactoryVar create/retrieve the associated ``object_class``."""

    mult_class = MultLorentz       # The class for the multiplication
    object_class = LorentzObject   # Define How to create the basic object.

    def __new__(cls, *args):
        name = cls.get_unique_name(*args)
        return FactoryVar.__new__(cls, name, cls.object_class, *args)

    @classmethod
    def get_unique_name(cls, *args):
        """Default way to have a unique name: '_L_<class name>_<arguments
        joined by '_'>'. Non-string arguments are converted with str()."""
        return '_L_%(class)s_%(args)s' % \
            {'class': cls.__name__,
             'args': '_'.join(str(arg) for arg in args)
             }
1055
1056 1057 #=============================================================================== 1058 # LorentzObjectRepresentation 1059 #=============================================================================== 1060 -class LorentzObjectRepresentation(dict):
1061 """A concrete representation of the LorentzObject.""" 1062 1063 vartype = 4 # Optimization for instance recognition 1064
class LorentzObjectRepresentationError(Exception):
    """Error raised for problems specific to LorentzObjectRepresentation."""
1067
1068 - def __init__(self, representation, lorentz_indices, spin_indices):
1069 """ initialize the lorentz object representation""" 1070 1071 self.lorentz_ind = lorentz_indices #lorentz indices 1072 self.nb_lor = len(lorentz_indices) #their number 1073 self.spin_ind = spin_indices #spin indices 1074 self.nb_spin = len(spin_indices) #their number 1075 self.nb_ind = self.nb_lor + self.nb_spin #total number of indices 1076 1077 1078 #store the representation 1079 if self.lorentz_ind or self.spin_ind: 1080 dict.__init__(self, representation) 1081 elif isinstance(representation,dict): 1082 if len(representation) == 0: 1083 self[(0,)] = 0 1084 elif len(representation) == 1 and (0,) in representation: 1085 self[(0,)] = representation[(0,)] 1086 else: 1087 raise self.LorentzObjectRepresentationError("There is no key of (0,) in representation.") 1088 else: 1089 if isinstance(representation,dict): 1090 try: 1091 self[(0,)] = representation[(0,)] 1092 except Exception: 1093 if representation: 1094 raise LorentzObjectRepresentation.LorentzObjectRepresentationError("There is no key of (0,) in representation.") 1095 else: 1096 self[(0,)] = 0 1097 else: 1098 self[(0,)] = representation
1099
1100 - def __str__(self):
1101 """ string representation """ 1102 text = 'lorentz index :' + str(self.lorentz_ind) + '\n' 1103 text += 'spin index :' + str(self.spin_ind) + '\n' 1104 #text += 'other info ' + str(self.tag) + '\n' 1105 for ind in self.listindices(): 1106 ind = tuple(ind) 1107 text += str(ind) + ' --> ' 1108 text += str(self.get_rep(ind)) + '\n' 1109 return text
1110
1111 - def get_rep(self, indices):
1112 """return the value/Variable associate to the indices""" 1113 return self[tuple(indices)]
1114
1115 - def set_rep(self, indices, value):
1116 """assign 'value' at the indices position""" 1117 1118 self[tuple(indices)] = value
1119
1120 - def listindices(self):
1121 """Return an iterator in order to be able to loop easily on all the 1122 indices of the object.""" 1123 return IndicesIterator(self.nb_ind)
1124 1125 @staticmethod
1126 - def get_mapping(l1,l2, switch_order=[]):
1127 shift = len(switch_order) 1128 for value in l1: 1129 try: 1130 index = l2.index(value) 1131 except Exception: 1132 raise LorentzObjectRepresentation.LorentzObjectRepresentationError( 1133 "Invalid addition. Object doen't have the same lorentz "+ \ 1134 "indices : %s != %s" % (l1, l2)) 1135 else: 1136 switch_order.append(shift + index) 1137 return switch_order
1138 1139
1140 - def __add__(self, obj, fact=1):
1141 1142 if not obj: 1143 return self 1144 1145 if not hasattr(obj, 'vartype'): 1146 assert self.lorentz_ind == [] 1147 assert self.spin_ind == [] 1148 new = self[(0,)] + obj * fact 1149 out = LorentzObjectRepresentation(new, [], []) 1150 return out 1151 1152 assert(obj.vartype == 4 == self.vartype) # are LorentzObjectRepresentation 1153 1154 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind: 1155 # if the order of indices are different compute a mapping 1156 switch_order = [] 1157 self.get_mapping(self.lorentz_ind, obj.lorentz_ind, switch_order) 1158 self.get_mapping(self.spin_ind, obj.spin_ind, switch_order) 1159 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))]) 1160 else: 1161 # no mapping needed (define switch as identity) 1162 switch = lambda ind : (ind) 1163 1164 # Some sanity check 1165 assert tuple(self.lorentz_ind+self.spin_ind) == tuple(switch(obj.lorentz_ind+obj.spin_ind)), '%s!=%s' % (self.lorentz_ind+self.spin_ind, switch(obj.lorentz_ind+self.spin_ind)) 1166 assert tuple(self.lorentz_ind) == tuple(switch(obj.lorentz_ind)), '%s!=%s' % (tuple(self.lorentz_ind), switch(obj.lorentz_ind)) 1167 1168 # define an empty representation 1169 new = LorentzObjectRepresentation({}, obj.lorentz_ind, obj.spin_ind) 1170 1171 # loop over all indices and fullfill the new object 1172 if fact == 1: 1173 for ind in self.listindices(): 1174 value = obj.get_rep(ind) + self.get_rep(switch(ind)) 1175 new.set_rep(ind, value) 1176 else: 1177 for ind in self.listindices(): 1178 value = fact * obj.get_rep(switch(ind)) + self.get_rep(ind) 1179 new.set_rep(ind, value) 1180 1181 return new
1182
1183 - def __iadd__(self, obj, fact=1):
1184 1185 if not obj: 1186 return self 1187 1188 assert(obj.vartype == 4 == self.vartype) # are LorentzObjectRepresentation 1189 1190 if self.lorentz_ind != obj.lorentz_ind or self.spin_ind != obj.spin_ind: 1191 1192 # if the order of indices are different compute a mapping 1193 switch_order = [] 1194 self.get_mapping(obj.lorentz_ind, self.lorentz_ind, switch_order) 1195 self.get_mapping(obj.spin_ind, self.spin_ind, switch_order) 1196 switch = lambda ind : tuple([ind[switch_order[i]] for i in range(len(ind))]) 1197 else: 1198 # no mapping needed (define switch as identity) 1199 switch = lambda ind : (ind) 1200 1201 # Some sanity check 1202 assert tuple(switch(self.lorentz_ind+self.spin_ind)) == tuple(obj.lorentz_ind+obj.spin_ind), '%s!=%s' % (switch(self.lorentz_ind+self.spin_ind), (obj.lorentz_ind+obj.spin_ind)) 1203 assert tuple(switch(self.lorentz_ind) )== tuple(obj.lorentz_ind), '%s!=%s' % (switch(self.lorentz_ind), tuple(obj.lorentz_ind)) 1204 1205 # loop over all indices and fullfill the new object 1206 if fact == 1: 1207 for ind in self.listindices(): 1208 self[tuple(ind)] += obj.get_rep(switch(ind)) 1209 else: 1210 for ind in self.listindices(): 1211 self[tuple(ind)] += fact * obj.get_rep(switch(ind)) 1212 return self
1213
1214 - def __sub__(self, obj):
1215 return self.__add__(obj, fact= -1)
1216
1217 - def __rsub__(self, obj):
1218 return obj.__add__(self, fact= -1)
1219
1220 - def __isub__(self, obj):
1221 return self.__add__(obj, fact= -1)
1222
1223 - def __neg__(self):
1224 self *= -1 1225 return self
1226
1227 - def __mul__(self, obj):
1228 """multiplication performing directly the einstein/spin sommation. 1229 """ 1230 1231 if not hasattr(obj, 'vartype'): 1232 out = LorentzObjectRepresentation({}, self.lorentz_ind, self.spin_ind) 1233 for ind in out.listindices(): 1234 out.set_rep(ind, obj * self.get_rep(ind)) 1235 return out 1236 1237 # Sanity Check 1238 assert(obj.__class__ == LorentzObjectRepresentation), \ 1239 '%s is not valid class for this operation' %type(obj) 1240 1241 # compute information on the status of the index (which are contracted/ 1242 #not contracted 1243 l_ind, sum_l_ind = self.compare_indices(self.lorentz_ind, \ 1244 obj.lorentz_ind) 1245 s_ind, sum_s_ind = self.compare_indices(self.spin_ind, \ 1246 obj.spin_ind) 1247 if not(sum_l_ind or sum_s_ind): 1248 # No contraction made a tensor product 1249 return self.tensor_product(obj) 1250 1251 # elsewher made a spin contraction 1252 # create an empty representation but with correct indices 1253 new_object = LorentzObjectRepresentation({}, l_ind, s_ind) 1254 #loop and fullfill the representation 1255 for indices in new_object.listindices(): 1256 #made a dictionary (pos -> index_value) for how call the object 1257 dict_l_ind = self.pass_ind_in_dict(indices[:len(l_ind)], l_ind) 1258 dict_s_ind = self.pass_ind_in_dict(indices[len(l_ind):], s_ind) 1259 #add the new value 1260 new_object.set_rep(indices, \ 1261 self.contraction(obj, sum_l_ind, sum_s_ind, \ 1262 dict_l_ind, dict_s_ind)) 1263 1264 return new_object
1265 1266 __rmul__ = __mul__ 1267 __imul__ = __mul__ 1268
1269 - def contraction(self, obj, l_sum, s_sum, l_dict, s_dict):
1270 """ make the Lorentz/spin contraction of object self and obj. 1271 l_sum/s_sum are the position of the sum indices 1272 l_dict/s_dict are dict given the value of the fix indices (indices->value) 1273 """ 1274 out = 0 # initial value for the output 1275 len_l = len(l_sum) #store len for optimization 1276 len_s = len(s_sum) # same 1277 1278 # loop over the possibility for the sum indices and update the dictionary 1279 # (indices->value) 1280 for l_value in IndicesIterator(len_l): 1281 l_dict.update(self.pass_ind_in_dict(l_value, l_sum)) 1282 for s_value in IndicesIterator(len_s): 1283 #s_dict_final = s_dict.copy() 1284 s_dict.update(self.pass_ind_in_dict(s_value, s_sum)) 1285 1286 #return the indices in the correct order 1287 self_ind = self.combine_indices(l_dict, s_dict) 1288 obj_ind = obj.combine_indices(l_dict, s_dict) 1289 1290 # call the object 1291 factor = obj.get_rep(obj_ind) * self.get_rep(self_ind) 1292 1293 if factor: 1294 #compute the prefactor due to the lorentz contraction 1295 try: 1296 factor.prefactor *= (-1) ** (len(l_value) - l_value.count(0)) 1297 except Exception: 1298 factor *= (-1) ** (len(l_value) - l_value.count(0)) 1299 out += factor 1300 return out
1301
1302 - def tensor_product(self, obj):
1303 """ return the tensorial product of the object""" 1304 assert(obj.vartype == 4) #isinstance(obj, LorentzObjectRepresentation)) 1305 1306 new_object = LorentzObjectRepresentation({}, \ 1307 self.lorentz_ind + obj.lorentz_ind, \ 1308 self.spin_ind + obj.spin_ind) 1309 1310 #some shortcut 1311 lor1 = self.nb_lor 1312 lor2 = obj.nb_lor 1313 spin1 = self.nb_spin 1314 spin2 = obj.nb_spin 1315 1316 #define how to call build the indices first for the first object 1317 if lor1 == 0 == spin1: 1318 #special case for scalar 1319 selfind = lambda indices: [0] 1320 else: 1321 selfind = lambda indices: indices[:lor1] + \ 1322 indices[lor1 + lor2: lor1 + lor2 + spin1] 1323 1324 #then for the second 1325 if lor2 == 0 == spin2: 1326 #special case for scalar 1327 objind = lambda indices: [0] 1328 else: 1329 objind = lambda indices: indices[lor1: lor1 + lor2] + \ 1330 indices[lor1 + lor2 + spin1:] 1331 1332 # loop on the indices and assign the product 1333 for indices in new_object.listindices(): 1334 1335 fac1 = self.get_rep(tuple(selfind(indices))) 1336 fac2 = obj.get_rep(tuple(objind(indices))) 1337 new_object.set_rep(indices, fac1 * fac2) 1338 1339 return new_object
1340
1341 - def factorize(self):
1342 """Try to factorize each component""" 1343 for ind, fact in self.items(): 1344 if fact: 1345 self.set_rep(ind, fact.factorize()) 1346 1347 1348 return self
1349
1350 - def simplify(self):
1351 """Check if we can simplify the object (check for non treated Sum)""" 1352 1353 #Look for internal simplification 1354 for ind, term in self.items(): 1355 if hasattr(term, 'vartype'): 1356 self[ind] = term.simplify() 1357 #no additional simplification 1358 return self
1359 1360 @staticmethod
1361 - def compare_indices(list1, list2):
1362 """return two list, the first one contains the position of non summed 1363 index and the second one the position of summed index.""" 1364 #init object 1365 1366 # equivalent set call --slightly slower 1367 #return list(set(list1) ^ set(list2)), list(set(list1) & set(list2)) 1368 1369 1370 are_unique, are_sum = [], [] 1371 # loop over the first list and check if they are in the second list 1372 1373 for indice in list1: 1374 if indice in list2: 1375 are_sum.append(indice) 1376 else: 1377 are_unique.append(indice) 1378 # loop over the second list for additional unique item 1379 1380 for indice in list2: 1381 if indice not in are_sum: 1382 are_unique.append(indice) 1383 1384 # return value 1385 return are_unique, are_sum
1386 1387 @staticmethod
1388 - def pass_ind_in_dict(indices, key):
1389 """made a dictionary (pos -> index_value) for how call the object""" 1390 if not key: 1391 return {} 1392 out = {} 1393 for i, ind in enumerate(indices): 1394 out[key[i]] = ind 1395 return out
1396
1397 - def combine_indices(self, l_dict, s_dict):
1398 """return the indices in the correct order following the dicts rules""" 1399 1400 out = [] 1401 # First for the Lorentz indices 1402 for value in self.lorentz_ind: 1403 out.append(l_dict[value]) 1404 # Same for the spin 1405 for value in self.spin_ind: 1406 out.append(s_dict[value]) 1407 1408 return out
1409
1410 - def split(self, variables_id):
1411 """return a dict with the key being the power associated to each variables 1412 and the value being the object remaining after the suppression of all 1413 the variable""" 1414 1415 out = SplitCoefficient() 1416 zero_rep = {} 1417 for ind in self.listindices(): 1418 zero_rep[tuple(ind)] = 0 1419 1420 for ind in self.listindices(): 1421 # There is no function split if the element is just a simple number 1422 if isinstance(self.get_rep(ind), numbers.Number): 1423 if tuple([0]*len(variables_id)) in out: 1424 out[tuple([0]*len(variables_id))][tuple(ind)] += self.get_rep(ind) 1425 else: 1426 out[tuple([0]*len(variables_id))] = \ 1427 LorentzObjectRepresentation(dict(zero_rep), 1428 self.lorentz_ind, self.spin_ind) 1429 out[tuple([0]*len(variables_id))][tuple(ind)] += self.get_rep(ind) 1430 continue 1431 1432 for key, value in self.get_rep(ind).split(variables_id).items(): 1433 if key in out: 1434 out[key][tuple(ind)] += value 1435 else: 1436 out[key] = LorentzObjectRepresentation(dict(zero_rep), 1437 self.lorentz_ind, self.spin_ind) 1438 out[key][tuple(ind)] += value 1439 1440 return out
1441
#===============================================================================
# IndicesIterator
#===============================================================================
class IndicesIterator:
    """Iterator over all the values of a list of indices.

    Each index runs from 0 to 3, the first index being the one which
    changes the fastest.  The list returned at each step is always the same
    object, updated in place: copy it (for instance with tuple) if it has to
    be kept.  Without any index (scalar object) the single value [0] is
    returned.
    """

    def __init__(self, len):
        """ create an iterator looping over the indices of a list of len "len"
        with each value can take value between 0 and 3 """

        self.len = len # number of indices
        if not len:
            # Special case for Scalar object: data is the flag telling if
            # the single value was already returned
            self.data = 0
        else:
            # The first index starts at -1 because the iteration increases
            # an index before returning the list
            self.data = [-1] + [0] * (len - 1)

    def __iter__(self):
        return self

    def __next__(self):
        """Move to the following list of indices and return it."""
        if not self.len:
            return self.nextscalar()

        position = 0
        while position < self.len:
            if self.data[position] >= 3:
                # maximal value reached: reset and carry to the next index
                self.data[position] = 0
                position += 1
                continue
            self.data[position] += 1
            return self.data
        raise StopIteration
    #Python2
    next = __next__

    def nextscalar(self):
        """Return [0] at the first call and stop the iteration afterwards."""
        if not self.data:
            self.data = True
            return [0]
        raise StopIteration
1487
class SplitCoefficient(dict):
    """Dictionary with an additional 'tag' attribute (a set, empty at the
    creation) and a helper returning the highest rank of its keys."""

    def __init__(self, *args, **opt):
        """Accept the same arguments as dict."""
        super(SplitCoefficient, self).__init__(*args, **opt)
        self.tag = set()

    def get_max_rank(self):
        """return the highest rank of the coefficient

        This is the largest value found in the first four entries of the
        keys.
        """

        return max(max(powers[:4]) for powers in self)
if __name__ == '__main__':

    import cProfile

    def create():
        """Call compare_indices 10000 times with a first list of varying
        length (0 to 9 elements) in order to profile it."""
        for step in range(10000):
            first = list(range(step % 10))
            LorentzObjectRepresentation.compare_indices(first, [4, 3, 5])

    # create is looked up by name in the namespace of the script
    cProfile.run('create()')