Package madgraph :: Package madevent :: Module hel_recycle
[hide private]
[frames] | [no frames]

Source Code for Module madgraph.madevent.hel_recycle

  1  #!/usr/bin/env python3 
  2   
  3  import argparse 
  4  import atexit 
  5  import re 
  6  import collections 
  7  from string import Template 
  8  from copy import copy 
  9  from itertools import product 
 10  from functools import reduce  
 11   
 12  try: 
 13       import madgraph 
 14  except: 
 15       import internal.misc as misc 
 16  else: 
 17       import madgraph.various.misc as misc 
 18  import mmap 
 19  try: 
 20      from tqdm import tqdm 
 21  except ImportError: 
 22      tqdm = misc.tqdm 
def get_num_lines(file_path):
    """Return the number of lines in *file_path*.

    The file is memory-mapped read-only and scanned with ``readline``. A
    final line without a trailing newline still counts as a line. An empty
    file yields 0.
    """
    with open(file_path, 'rb') as fp:
        try:
            buf = mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ)
        except ValueError:
            # mmap refuses to map a zero-length file.
            return 0
        with buf:
            return sum(1 for _ in iter(buf.readline, b''))
32
class DAG:
    """Graph linking generated wavefunction objects to their dependencies.

    Attributes:
        graph: maps a node to the list of nodes recorded for it with
            ``add_branch`` (``store_wav`` records one edge per ext_dep).
        all_wavs: every stored object, in insertion order.
        external_wavs / internal_wavs: the stored objects whose ``nature``
            is 'external' / 'internal' respectively.
    """

    def __init__(self):
        self.graph = {}
        self.all_wavs = []
        self.external_wavs = []
        self.internal_wavs = []

    def store_wav(self, wav, ext_deps=None):
        """Register *wav* and add an edge from it to each item of *ext_deps*.

        ``ext_deps`` may be any iterable (the callers pass a set) or None,
        meaning no edges.
        """
        self.all_wavs.append(wav)
        nature = wav.nature
        if nature == 'external':
            self.external_wavs.append(wav)
        elif nature == 'internal':
            self.internal_wavs.append(wav)
        for ext in ext_deps or ():
            self.add_branch(wav, ext)

    def add_branch(self, node_i, node_f):
        """Add a directed edge node_i -> node_f."""
        self.graph.setdefault(node_i, []).append(node_f)

    def dependencies(self, old_name):
        """Return the live (not ``dead``) stored objects whose old_name matches."""
        return [wav for wav in self.all_wavs
                if wav.old_name == old_name and not wav.dead]

    def kill_old(self, old_name):
        """Mark every stored object carrying *old_name* as dead."""
        for wav in self.all_wavs:
            if wav.old_name == old_name:
                wav.dead = True

    def old_names(self):
        """Return the set of all old names seen so far (dead ones included)."""
        return {wav.old_name for wav in self.all_wavs}

    def find_path(self, start, end, path=None):
        '''Return a list of nodes from *start* to *end*, or None if unreachable.

        Depth-first search taken from https://www.python.org/doc/essays/graphs/
        '''
        path = (path or []) + [start]
        if start == end:
            return path
        if start not in self.graph:
            return None
        for node in self.graph[start]:
            if node not in path:
                newpath = self.find_path(node, end, path)
                if newpath:
                    return newpath
        return None

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        by_new = '\n\t'.join(f'{key} : {item}'
                             for key, item in self.graph.items())
        by_old = '\n\t'.join(f'{key.old_name} : {[i.old_name for i in item]}'
                             for key, item in self.graph.items())
        return f'With new names:\n\t{by_new}\n\nWith old names:\n\t{by_old}'
94
class MathsObject:
    '''Abstract class for wavefunctions and Amplitudes'''

    # Store here which externals the last wav/amp depends on.
    # This saves us having to call find_path multiple times.
    ext_deps = None

    def __init__(self, arguments, old_name, nature):
        self.args = arguments      # argument list of the HELAS call; last entry is the output
        self.old_name = old_name   # output name in the original (un-recycled) call
        self.nature = nature       # 'external', 'internal' or 'amplitude'
        self.name = None
        self.dead = False
        self.nb_used = 0           # how many kept calls consume this object
        self.linkdag = []          # objects this one consumes

    def set_name(self, *args):
        """Rename the output argument using the subclass' ``format_name``."""
        self.args[-1] = self.format_name(*args)
        self.name = self.args[-1]

    def format_name(self, *nums):
        """Overridden by subclasses; the base version returns None."""
        pass

    @staticmethod
    def get_deps(line, graph):
        """Return, per input wavefunction of *line*, the live graph objects for it.

        As a side effect, every object already stored under this call's output
        name is marked dead, because the call overwrites that variable.
        """
        old_args = get_arguments(line)
        old_name = old_args[-1].replace(' ', '')
        matches = graph.old_names() & {old.replace(' ', '') for old in old_args}
        matches.discard(old_name)
        # Assumes the known wavefunctions are the leading arguments of the call.
        old_deps = old_args[:len(matches)]

        # If we're overwriting a wav clear it from graph
        graph.kill_old(old_name)
        return [graph.dependencies(dep) for dep in old_deps]

    @classmethod
    def good_helicity(cls, wavs, graph, diag_number=None, all_hel=None,
                      bad_hel_amp=None):
        """Decide whether the combination *wavs* corresponds to a kept helicity.

        Sets ``cls.ext_deps`` to the external wavefunctions reachable from
        *wavs*. Returns that set when it fits inside one of
        ``External.good_wav_combs`` (and, for an amplitude with
        *diag_number*, the (helicity number, diagram) pair is not listed in
        *bad_hel_amp*); otherwise returns a falsy value.
        """
        if all_hel is None:
            all_hel = []
        if bad_hel_amp is None:
            bad_hel_amp = []
        exts = graph.external_wavs
        cls.ext_deps = {i for dep in wavs for i in exts
                        if graph.find_path(dep, i)}
        this_comb_good = any(cls.ext_deps.issubset(comb)
                             for comb in External.good_wav_combs)

        if diag_number and this_comb_good and cls.ext_deps:
            helicity = {a.get_id(): a.hel for a in cls.ext_deps}
            this_hel = tuple(helicity[i] for i in range(1, len(helicity) + 1))
            # Helicity numbers are 1-based, as in the Fortran NHEL table.
            hel_number = 1 + all_hel.index(this_hel)
            if (hel_number, diag_number) in bad_hel_amp:
                this_comb_good = False

        return this_comb_good and cls.ext_deps

    @staticmethod
    def get_new_args(line, wavs):
        """Return the call's arguments with the leading inputs renamed to *wavs*."""
        this_args = get_arguments(line)
        this_args[0:len(wavs)] = [w.name for w in wavs]
        return this_args

    @staticmethod
    def get_number(*args):
        """Overridden by subclasses; the base version returns None."""
        pass

    @classmethod
    def get_obj(cls, line, wavs, graph, diag_num=None):
        """Build the renamed object for input combination *wavs* and store it.

        Amplitudes are not stored in the graph.
        """
        old_name = get_arguments(line)[-1].replace(' ', '')
        new_args = cls.get_new_args(line, wavs)
        num = cls.get_number(wavs, graph)

        this_obj = cls.call_constructor(new_args, old_name, diag_num)
        this_obj.set_name(num, diag_num)
        if this_obj.nature != 'amplitude':
            graph.store_wav(this_obj, cls.ext_deps)
        return this_obj

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name
class External(MathsObject):
    '''Class for storing external wavefunctions.

    Class-level state shared by every instance:
        good_hel, nhel_lines, num_externals, wavs_same_leg, good_wav_combs
            are reset by HelicityRecycler.__init__.
        good_hel, map_hel and hel_ranges are filled by
            HelicityRecycler.get_good_hel.
        wavs_same_leg: 0-based leg index -> External objects generated for it.
        good_wav_combs: lists of External objects forming an allowed
            helicity combination (built by get_gwc).
    '''

    good_hel = []
    nhel_lines = ''
    num_externals = 0
    # Could get this from dag but I'm worried about preserving order
    wavs_same_leg = {}
    good_wav_combs = []
    # helicity tuple -> 0-based position in good_hel (set by get_good_hel)
    map_hel = {}
    # per leg, the set of helicities appearing in good_hel (set by get_good_hel)
    hel_ranges = []

    def __init__(self, arguments, old_name):
        super().__init__(arguments, old_name, 'external')
        self.hel = int(self.args[2])
        # Leg number parsed from the first argument, e.g. 'P(0,3)' -> 3.
        self.mg = int(arguments[0].split(',')[-1][:-1])
        # Instance attribute kept for compatibility; generate_wavfuncs reads
        # the class-level External.hel_ranges instead.
        self.hel_ranges = []
        self.raise_num()

    @classmethod
    def raise_num(cls):
        cls.num_externals += 1

    @classmethod
    def generate_wavfuncs(cls, line, graph):
        """Expand an external-spinor call into one object per allowed helicity.

        The new objects are stored in *graph*, appended to
        ``wavs_same_leg`` for their leg and returned.
        """
        # If graph is passed in Internal it should be done here to so
        # we can set names
        old_args = get_arguments(line)
        old_name = old_args[-1].replace(' ', '')
        graph.kill_old(old_name)

        if 'NHEL' in old_args[2].upper():
            nhel_index = re.search(r'\(.*?\)', old_args[2]).group()
            ext_num = int(nhel_index[1:-1]) - 1
            new_hels = [int_to_string(i)
                        for i in sorted(External.hel_ranges[ext_num], reverse=True)]
        else:
            # Spinor must be a scalar so give it hel = 0
            ext_num = int(re.search(r'\(0,(\d+)\)', old_args[0]).group(1)) - 1
            new_hels = [' 0']

        new_wavfuncs = []
        for hel in new_hels:
            this_args = copy(old_args)
            this_args[2] = hel

            this_wavfunc = External(this_args, old_name)
            this_wavfunc.set_name(len(graph.external_wavs) + len(graph.internal_wavs) + 1)

            graph.store_wav(this_wavfunc)
            new_wavfuncs.append(this_wavfunc)

        # Store a copy so later extensions of the leg's list never mutate the
        # list handed back to the caller.
        cls.wavs_same_leg.setdefault(ext_num, []).extend(new_wavfuncs)
        return new_wavfuncs

    @classmethod
    def get_gwc(cls):
        """Build ``good_wav_combs``.

        For each kept helicity combination, this enumerates every choice of
        one External per leg whose helicity matches that combination.
        """
        gwc = []
        for comb in cls.good_hel:
            sols = [[]]
            for leg, wavs in cls.wavs_same_leg.items():
                valid = [wav for wav in wavs if comb[leg] == wav.hel]
                if len(valid) == 1:
                    for sol in sols:
                        sol.append(valid[0])
                else:
                    sols = [sol + [w] for w in valid for sol in sols]
            gwc += sols

        cls.good_wav_combs = gwc

    @staticmethod
    def format_name(*nums):
        return f'W(1,{nums[0]})'

    def get_id(self):
        """ return the id of the particle under consideration """
        try:
            return self.id
        except AttributeError:
            # Lazily parse and cache the leg number from 'P(0,n)'.
            self.id = int(re.findall(r'P\(0,(\d+)\)', self.args[0])[0])
            return self.id
class Internal(MathsObject):
    """Internal wavefunction produced by a HELAS call with known inputs.

    Class-level counters:
        num_internals -- Internal objects constructed so far.
        max_wav_num   -- largest wavefunction number returned by get_number.
    """

    max_wav_num = 0
    num_internals = 0

    def __init__(self, arguments, old_name):
        super().__init__(arguments, old_name, 'internal')
        self.raise_num()

    @classmethod
    def raise_num(cls):
        cls.num_internals += 1

    @classmethod
    def generate_wavfuncs(cls, line, graph):
        """Return one Internal per input combination with a kept helicity."""
        candidates = product(*cls.get_deps(line, graph))
        return [cls.get_obj(line, combo, graph)
                for combo in candidates
                if cls.good_helicity(combo, graph)]

    @classmethod
    def call_constructor(cls, new_args, old_name, diag_num):
        # diag_num is part of the shared get_obj protocol; unused here.
        return Internal(new_args, old_name)

    @classmethod
    def get_number(cls, *args):
        """Return the next free wavefunction number and track the maximum."""
        candidate = External.num_externals + Internal.num_internals + 1
        if candidate > cls.max_wav_num:
            cls.max_wav_num = candidate
        return candidate

    @staticmethod
    def format_name(*nums):
        return f'W(1,{nums[0]})'
class Amplitude(MathsObject):
    '''Class for storing Amplitudes.

    ``max_amp_num`` tracks the largest helicity slot returned by get_number.
    '''

    max_amp_num = 0

    def __init__(self, arguments, old_name, diag_num):
        self.diag_num = diag_num
        super().__init__(arguments, old_name, 'amplitude')

    @staticmethod
    def format_name(*nums):
        return f'AMP({nums[0]},{nums[1]})'

    @classmethod
    def generate_amps(cls, line, graph, all_hel=None, all_bad_hel=None):
        """Return one Amplitude per input combination with a kept helicity.

        The diagram number is read from the call's output, e.g. 'AMP(7)' -> 7.
        ``all_bad_hel`` lists (helicity number, diagram number) pairs to drop.
        """
        if all_bad_hel is None:
            all_bad_hel = []
        old_args = get_arguments(line)
        old_name = old_args[-1].replace(' ', '')

        amp_index = re.search(r'\(.*?\)', old_name).group()
        diag_num = int(amp_index[1:-1])

        deps = cls.get_deps(line, graph)

        return [cls.get_obj(line, wavs, graph, diag_num)
                for wavs in product(*deps)
                if cls.good_helicity(wavs, graph, diag_num, all_hel, all_bad_hel)]

    @classmethod
    def call_constructor(cls, new_args, old_name, diag_num):
        return Amplitude(new_args, old_name, diag_num)

    @classmethod
    def get_number(cls, *args):
        """Return the 1-based helicity slot of the last checked combination.

        Uses ``cls.ext_deps`` as set by good_helicity, ordered by leg number.
        Called as get_number(wavs, graph); both arguments are unused.
        """
        _wavs, _graph = args
        hel_amp = tuple(w.hel for w in sorted(cls.ext_deps, key=lambda x: x.mg))
        amp_num = External.map_hel[hel_amp] + 1  # Offset because Fortran counts from 1

        if cls.max_amp_num < amp_num:
            cls.max_amp_num = amp_num
        return amp_num
379
class HelicityRecycler():
    '''Rewrite a matrix-element Fortran file with HELAS calls unrolled over
    the contributing helicities only.

    Typical use: construct with the good helicity indices, call the
    set_input / set_output / set_template setters, then generate_output_file().
    '''

    def __init__(self, good_elements, bad_amps=None, bad_amps_perhel=None):
        """
        Args:
            good_elements: 1-based indices (anything ``int()`` accepts) into
                the NHEL DATA table of the helicity combinations to keep.
            bad_amps: diagram numbers whose amplitude calls are dropped; they
                are compared with the number captured as a string from 'AMP(n)'.
            bad_amps_perhel: (helicity number, diagram number) pairs to drop.
        """
        # Reset the class-level state of the wavefunction classes so that a
        # previous recycler run does not leak into this one.
        External.good_hel = []
        External.nhel_lines = ''
        External.num_externals = 0
        External.wavs_same_leg = {}
        External.good_wav_combs = []

        Internal.max_wav_num = 0
        Internal.num_internals = 0

        Amplitude.max_amp_num = 0
        self.last_category = None
        self.good_elements = good_elements
        # Fresh per-instance containers (mutable defaults would be shared).
        self.bad_amps = [] if bad_amps is None else bad_amps
        self.bad_amps_perhel = [] if bad_amps_perhel is None else bad_amps_perhel

        # Default file names
        self.input_file = 'matrix_orig.f'
        self.output_file = 'matrix_orig.f'
        self.template_file = 'template_matrix.f'

        # Initialise everything as for a zero matrix element.
        self.template_dict = {
            'helicity_lines': '\n',
            'helas_calls': [],
            'jamp_lines': '\n',
            'amp2_lines': '\n',
            'ncomb': '0',
            'nwavefuncs': '0',
        }

        self.dag = DAG()

        self.diag_num = 1
        self.got_gwc = False

        self.procedure_name = self.input_file.split('.')[0].upper()
        self.procedure_kind = 'FUNCTION'

        self.old_out_name = ''
        self.loop_var = 'K'

        self.all_hel = []
        self.hel_filt = True
        # Amplitude splitting is on by default (same default as the
        # --as-off command-line flag); apply_amps needs it to exist.
        self.amp_splt = True
        self.prepare_bools()

    def set_input(self, file):
        """Select the Fortran file to read; exits with status 1 for born_matrix files."""
        if 'born_matrix' in file:
            print('HelicityRecycler is currently '
                  f'unable to handle {file}')
            raise SystemExit(1)
        self.procedure_name = file.split('.')[0].upper()
        self.procedure_kind = 'FUNCTION'
        self.input_file = file

    def set_output(self, file):
        self.output_file = file

    def set_template(self, file):
        self.template_file = file

    def function_call(self, line):
        """Classify *line* as 'external', 'internal', 'amplitude' or None."""
        # Check a function is called at all
        if 'CALL' not in line:
            return None

        # Now check for external spinor
        ext_calls = ['CALL OXXXXX', 'CALL IXXXXX', 'CALL VXXXXX', 'CALL SXXXXX']
        if any(call in line for call in ext_calls):
            return 'external'

        # An internal/amplitude call is assumed to appear only after
        # at least one external wavefunction has been seen.
        if not self.dag.external_wavs:
            return None

        try:
            function = line.split('(', 1)[0].split()[-1]
        except IndexError:
            return None
        # HELAS routines whose name ends in '_0' produce an amplitude; any
        # other suffix produces an internal wavefunction.
        if function.split('_')[-1] == '0':
            return 'amplitude'
        return 'internal'

    # string manipulation

    def add_amp_index(self, matchobj):
        """re.sub callback: insert the loop variable as the first AMP index."""
        old_pat = matchobj.group()
        return old_pat.replace('AMP(', 'AMP( %s,' % self.loop_var)

    def add_indices(self, line):
        '''Add loop_var index to amp and output variable.
        Also update name of output variable.'''
        # Doesnt work if the AMP arguments contain brackets
        new_line = re.sub(r'\WAMP\(.*?\)', self.add_amp_index, line)
        new_line = re.sub(r'MATRIX\d+', 'TS(K)', new_line)
        return new_line

    def jamp_finished(self, line):
        """The JAMP section ends at the first line mentioning init_mode."""
        return 'init_mode' in line.lower()

    def get_old_name(self, line):
        """Record the output variable name of the procedure being rewritten."""
        if f'{self.procedure_kind} {self.procedure_name}' in line:
            if 'SUBROUTINE' == self.procedure_kind:
                self.old_out_name = get_arguments(line)[-1]
            elif 'FUNCTION' == self.procedure_kind:
                self.old_out_name = line.split('(')[0].split()[-1]

    def get_amp_stuff(self, line_num, line):
        """Track the diagram / JAMP / AMP2 sections and collect their lines."""
        if 'diagram number' in line:
            self.amp_calc_started = True
        # Check if the calculation of this diagram is finished
        # (line[:1] instead of line[0] so an empty line cannot raise).
        if ('AMP' not in get_arguments(line)[-1]
                and self.amp_calc_started and line[:1] != 'C'):
            # Check if the calculation of all diagrams is finished
            if self.function_call(line) not in ['external',
                                                'internal',
                                                'amplitude']:
                self.jamp_started = True
                self.amp_calc_started = False
        if self.jamp_started:
            self.get_jamp_lines(line)
        if self.in_amp2:
            self.get_amp2_lines(line)
        if self.find_amp2 and line.startswith(' ENDDO'):
            self.in_amp2 = True
            self.find_amp2 = False

    def get_jamp_lines(self, line):
        if self.jamp_finished(line):
            self.jamp_started = False
            self.find_amp2 = True
        elif not line.isspace():
            self.template_dict['jamp_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'

    def get_amp2_lines(self, line):
        if line.startswith(' DO I = 1, NCOLOR'):
            self.in_amp2 = False
        elif not line.isspace():
            self.template_dict['amp2_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'

    def prepare_bools(self):
        """Reset the section-tracking flags used while reading the input."""
        self.amp_calc_started = False
        self.jamp_started = False
        self.find_amp2 = False
        self.in_amp2 = False
        self.nhel_started = False

    @staticmethod
    def _link_dependencies(obj, name2dep):
        """Record which known wavefunctions *obj* consumes and bump their use count."""
        obj.linkdag = []
        for name in obj.args:
            if name in name2dep:
                name2dep[name].nb_used += 1
                obj.linkdag.append(name2dep[name])

    def unfold_helicities(self, line, nature):
        """Expand one HELAS call into its per-helicity objects.

        Returns the list of new objects, or '' when the amplitude's diagram
        is listed in ``bad_amps``.
        """
        if nature not in ['external', 'internal', 'amplitude']:
            raise ValueError('wrong unfolding')

        if nature == 'external':
            new_objs = External.generate_wavfuncs(line, self.dag)
            for obj in new_objs:
                obj.line = apply_args(line, [obj.args])
            return new_objs

        deps = Amplitude.get_deps(line, self.dag)
        name2dep = {d.name: d for d in sum(deps, [])}

        if nature == 'internal':
            new_objs = Internal.generate_wavfuncs(line, self.dag)
            for obj in new_objs:
                obj.line = apply_args(line, [obj.args])
                self._link_dependencies(obj, name2dep)
            return new_objs

        # nature == 'amplitude'
        nb_diag = re.findall(r'AMP\((\d+)\)', line)[0]
        if nb_diag in self.bad_amps:
            return ''
        new_objs = Amplitude.generate_amps(line, self.dag, self.all_hel,
                                           self.bad_amps_perhel)
        out_line = self.apply_amps(line, new_objs)
        for i, obj in enumerate(new_objs):
            # The first amplitude carries the whole generated code; the
            # others are produced by that same code, so they get no line.
            obj.line = out_line if i == 0 else ''
            obj.nb_used = 1
            self._link_dependencies(obj, name2dep)
        return new_objs

    def apply_amps(self, line, new_objs):
        if self.amp_splt:
            return split_amps(line, new_objs)
        return apply_args(line, [i.args for i in new_objs])

    def get_gwc(self, line, category):
        """Rebuild External.good_wav_combs on the first HELAS call after an external one."""
        if category not in ['external', 'internal', 'amplitude']:
            return
        if self.last_category == 'external':
            External.get_gwc()
        self.last_category = category

    def get_good_hel(self, line):
        """Collect the NHEL DATA table; when it ends, select the good helicities."""
        if 'DATA (NHEL' in line:
            self.nhel_started = True
            this_hel = [int(hel) for hel in line.split('/')[1].split(',')]
            self.all_hel.append(tuple(this_hel))
        elif self.nhel_started:
            self.nhel_started = False

            if self.hel_filt:
                External.good_hel = [self.all_hel[int(i) - 1] for i in self.good_elements]
            else:
                External.good_hel = self.all_hel

            External.map_hel = {hel: i for i, hel in enumerate(External.good_hel)}
            External.hel_ranges = [set() for _ in External.good_hel[0]]
            for comb in External.good_hel:
                for i, hel in enumerate(comb):
                    External.hel_ranges[i].add(hel)

            self.counter = 0
            nhel_lines = '\n'.join(self.nhel_string(hel) for hel in External.good_hel)
            self.template_dict['helicity_lines'] += nhel_lines

            self.template_dict['ncomb'] = len(External.good_hel)

    def nhel_string(self, hel_comb):
        """Return the next NHEL DATA statement (numbering via self.counter)."""
        self.counter += 1
        formatted_hel = [f'{hel}' if hel < 0 else f' {hel}' for hel in hel_comb]
        nexternal = len(hel_comb)
        return (f' DATA (NHEL(I,{self.counter}),I=1,{nexternal}) /{",".join(formatted_hel)}/')

    def read_orig(self):
        """Parse the input file and fill template_dict['helas_calls'] etc."""
        with open(self.input_file, 'r') as input_file:

            self.prepare_bools()

            for line_num, line in tqdm(enumerate(input_file), total=get_num_lines(self.input_file)):
                # Lines are processed one step late so that continuation
                # lines ('$' in column 6) can be merged into the cached line.
                if line_num == 0:
                    line_cache = line
                    continue

                if '!SKIP' in line:
                    continue

                if line[5:6] == '$':
                    line_cache = undo_multiline(line_cache, line)
                    continue

                line, line_cache = line_cache, line

                self.get_old_name(line)
                self.get_good_hel(line)
                self.get_amp_stuff(line_num, line)
                call_type = self.function_call(line)
                self.get_gwc(line, call_type)

                if call_type in ['external', 'internal', 'amplitude']:
                    self.template_dict['helas_calls'] += self.unfold_helicities(
                        line, call_type)

        self.template_dict['nwavefuncs'] = max(External.num_externals, Internal.max_wav_num)
        calls = self.template_dict['helas_calls']
        # Filter out useless calls. Walk backwards so that dropping a call
        # releases its inputs before they themselves are examined.
        for obj in reversed(calls):
            if obj.nb_used == 0:
                obj.line = ''
                for dep in obj.linkdag:
                    dep.nb_used -= 1

        self.template_dict['helas_calls'] = '\n'.join(
            f'{obj.line.rstrip()} ! count {obj.nb_used}'
            for obj in calls
            if obj.nb_used > 0 and obj.line)

    def _render_template(self):
        """Substitute template_dict into the template and write the output file."""
        with open(self.output_file, 'w+') as out_file, \
                open(self.template_file, 'r') as template:
            for line in template:
                line = Template(line).safe_substitute(self.template_dict)
                line = '\n'.join(do_multiline(sub_line) for sub_line in line.split('\n'))
                out_file.write(line)

    def read_template(self):
        self._render_template()

    def write_zero_matrix_element(self):
        """Write the output for a process with no contributing helicity."""
        self.template_dict['ncomb'] = '0'
        self.template_dict['nwavefuncs'] = '0'
        self.template_dict['helas_calls'] = ''
        self._render_template()

    def generate_output_file(self):
        if not self.good_elements:
            misc.sprint("No helicity", self.input_file)
            self.write_zero_matrix_element()
            return

        atexit.register(self.clean_up)
        self.read_orig()
        self.read_template()
        atexit.unregister(self.clean_up)

    def clean_up(self):
        pass
748
def get_arguments(line):
    """Split the text inside the first top-level pair of parentheses on commas.

    Commas nested inside inner parentheses do not split, and whitespace is
    preserved. Example: 'CALL F(A, B(1,2))' -> ['A', ' B(1,2)'].
    """
    depth = 0
    pieces = ['']
    for ch in line:
        if ch == '(':
            depth += 1
            if depth == 1:
                # Opening bracket of the argument list itself: not content.
                continue
        elif ch == ')':
            depth -= 1
            if depth == 0:
                # Closing bracket of the argument list: done.
                break
        elif ch == ',' and depth == 1:
            pieces.append('')
            continue
        if depth > 0:
            pieces[-1] += ch
    return pieces
776
def apply_args(old_line, all_the_args):
    """Return one copy of the call in *old_line* per argument list.

    Each copy has its argument text replaced by the list in *all_the_args*
    and ends with a newline.
    """
    callee = old_line.split('(')[0].split()[-1]
    arg_text = old_line.split(callee)[-1]
    return ''.join(
        old_line.replace(arg_text, '(' + ','.join(args) + ')\n')
        for args in all_the_args
    )
785
def split_amps(line, new_amps):
    """Rewrite the amplitude call *line* so work shared between helicities is reused.

    The wavefunction column with the most distinct entries is contracted
    last. For every combination of the remaining columns, one
    '<fct>P1N_<col>' call fills TMP, and a CombineAmp / CombineAmpS call
    then produces every AMP for that combination. Combinations with a single
    amplitude fall back to a plain call. Returns the generated lines joined
    by newlines, or '' when *new_amps* is empty.

    Raises:
        Exception: if the contracted leg is neither a vector, fermion nor scalar.
    """
    if not new_amps:
        return ''
    fct = line.split('(', 1)[0].split('_0')[0]

    # occur[c] counts how often each wavefunction appears in wavefunction
    # column c over all amplitudes. Columns are the 'W(1,' arguments of the
    # first amplitude; they are assumed to lead the argument list.
    first, *rest = new_amps
    occur = [collections.Counter([a]) for a in first.args if "W(1," in a]
    for amp in rest:
        for col, counter in enumerate(occur):
            counter[amp.args[col]] += 1

    nb_wav = [len(o) for o in occur]
    to_remove = nb_wav.index(max(nb_wav))
    # Remove the one that occurs the most
    occur.pop(to_remove)

    lines = []
    # Get the wavs per column
    wav_name = [o.keys() for o in occur]
    for wfcts in product(*wav_name):
        # Select the amplitudes produced by wfcts
        sub_amps = [amp for amp in new_amps
                    if all(w in amp.args for w in wfcts)]
        if not sub_amps:
            continue
        if len(sub_amps) == 1:
            lines.append(apply_args(line, [a.args for a in sub_amps]).replace('\n', ''))
            continue

        # Sort by AMP's second index so the generated code is stable.
        sub_amps.sort(key=lambda a: int(a.args[-1][:-1].split(',', 1)[1]))
        windices = []
        hel_calculated = []
        iamp = 0
        for n, amp in enumerate(sub_amps):
            args = amp.args[:]
            # Remove wav and get its index
            wcontract = args.pop(to_remove)
            windex = wcontract.split(',')[1].split(')')[0]
            windices.append(windex)
            amp_result, args[-1] = args[-1], 'TMP(1)'

            if n == 0:
                # Call the original fct with P1N_... ; final arg is TMP(1).
                # The spin letter of the contracted leg is read from the
                # routine name at that position.
                spin = fct.split(None, 1)[1][to_remove]
                lines.append('%sP1N_%s(%s)' % (fct, to_remove + 1, ', '.join(args)))

            hel, iamp = re.findall(r'AMP\((\d+),(\d+)\)', amp_result)[0]
            hel_calculated.append(hel)

        if spin in "VF":
            lines.append(""" call CombineAmp(%(nb)i,
 & (/%(hel_list)s/),
 & (/%(w_list)s/),
 & TMP, W, AMP(1,%(iamp)s))""" %
                         {'nb': len(sub_amps),
                          'hel_list': ','.join(hel_calculated),
                          'w_list': ','.join(windices),
                          'iamp': iamp
                          })
        elif spin == "S":
            lines.append(""" call CombineAmpS(%(nb)i,
 &(/%(hel_list)s/),
 & (/%(w_list)s/),
 & TMP, W, AMP(1,%(iamp)s))""" %
                         {'nb': len(sub_amps),
                          'hel_list': ','.join(hel_calculated),
                          'w_list': ','.join(windices),
                          'iamp': iamp
                          })
        else:
            raise Exception("split amp are not supported for spin2 and 3/2")

    return '\n'.join(lines)
873
def get_num(wav):
    """Return the last integer inside the first parentheses of ``wav.name``.

    Example: a name of 'W(1,12)' gives 12.
    """
    bracketed = re.search(r'\(.*?\)', wav.name).group()
    *_, last_field = bracketed[1:-1].split(',')
    return int(last_field)
879
def undo_multiline(old_line, new_line):
    """Append a fixed-form continuation line to the line it continues.

    The first 6 columns of *new_line* are dropped, and newlines are removed
    from *old_line* before joining.
    """
    head = old_line.replace('\n', '')
    tail = new_line[6:]
    return head + tail
884
def do_multiline(line, char_limit=72):
    """Split *line* into fixed-form continuation lines of *char_limit* characters.

    Lines no longer than *char_limit*, and lines containing '!', are returned
    unchanged. Otherwise the line is cut every *char_limit* characters, and
    each piece after the first is prefixed by '\\n $' plus the run of spaces
    that follows column 6 of the original line. Unlike the former version,
    a line whose length is an exact multiple of *char_limit* no longer gets
    an empty trailing continuation.
    """
    if len(line) <= char_limit or '!' in line:
        return line
    chunks = [line[start:start + char_limit]
              for start in range(0, len(line), char_limit)]
    body = line[6:]
    indent = ' ' * (len(body) - len(body.lstrip(' ')))
    return f'\n ${indent}'.join(chunks)
899
def int_to_string(i):
    """Format a helicity (-1, 0 or +1) as a signed two-character string.

    Raises:
        ValueError: if *i* is not a valid helicity. The old code called an
            undefined ``set_trace`` here and died with a NameError.
    """
    formatted = {1: '+1', 0: ' 0', -1: '-1'}
    try:
        return formatted[i]
    except KeyError:
        raise ValueError(f'How can {i} be a helicity?') from None
911
def main():
    """Command-line entry point.

    Reads the good helicity indices from the first line of ``hel_file`` and
    writes the recycled matrix element to 'green_matrix.f', using
    'template_matrix1.f' as the template.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input_file',
                        help='The file containing the original matrix calculation')
    parser.add_argument('hel_file',
                        help='The file containing the contributing helicities')
    parser.add_argument('--hf-off', dest='hel_filt', action='store_false',
                        default=True, help='Disable helicity filtering')
    parser.add_argument('--as-off', dest='amp_splt', action='store_false',
                        default=True, help='Disable amplitude splitting')
    options = parser.parse_args()

    with open(options.hel_file, 'r') as hel_stream:
        good_elements = hel_stream.readline().split()

    recycler = HelicityRecycler(good_elements)
    recycler.hel_filt = options.hel_filt
    recycler.amp_splt = options.amp_splt

    recycler.set_input(options.input_file)
    recycler.set_output('green_matrix.f')
    recycler.set_template('template_matrix1.f')

    recycler.generate_output_file()


if __name__ == '__main__':
    main()