1
2
3 import argparse
4 import atexit
5 import re
6 import collections
7 from string import Template
8 from copy import copy
9 from itertools import product
10 from functools import reduce
11
12 try:
13 import madgraph
14 except:
15 import internal.misc as misc
16 else:
17 import madgraph.various.misc as misc
18 import mmap
19 try:
20 from tqdm import tqdm
21 except ImportError:
22 tqdm = misc.tqdm
def get_num_lines(file_path):
    """Return the number of lines in the file at ``file_path``.

    A final line without a trailing newline still counts as a line, and an
    empty file yields 0.

    The file is opened read-only and closed before returning; no memory map
    is created, so empty files and files without write permission are fine.
    """
    # Binary mode: lines are split on '\n' only and nothing is decoded.
    with open(file_path, 'rb') as handle:
        return sum(1 for _ in handle)
32
34
36 self.graph = {}
37 self.all_wavs = []
38 self.external_wavs = []
39 self.internal_wavs = []
40
42 self.all_wavs.append(wav)
43 nature = wav.nature
44 if nature == 'external':
45 self.external_wavs.append(wav)
46 if nature == 'internal':
47 self.internal_wavs.append(wav)
48 for ext in ext_deps:
49 self.add_branch(wav, ext)
50
52 try:
53 self.graph[node_i].append(node_f)
54 except KeyError:
55 self.graph[node_i] = [node_f]
56
58 deps = [wav for wav in self.all_wavs
59 if wav.old_name == old_name and not wav.dead]
60 return deps
61
63 for wav in self.all_wavs:
64 if wav.old_name == old_name:
65 wav.dead = True
66
68 return {wav.old_name for wav in self.all_wavs}
69
71 '''Taken from https://www.python.org/doc/essays/graphs/'''
72
73 path = path + [start]
74 if start == end:
75 return path
76 if start not in self.graph:
77 return None
78 for node in self.graph[start]:
79 if node not in path:
80 newpath = self.find_path(node, end, path)
81 if newpath:
82 return newpath
83 return None
84
87
89 print_str = 'With new names:\n\t'
90 print_str += '\n\t'.join([f'{key} : {item}' for key, item in self.graph.items() ])
91 print_str += '\n\nWith old names:\n\t'
92 print_str += '\n\t'.join([f'{key.old_name} : {[i.old_name for i in item]}' for key, item in self.graph.items() ])
93 return print_str
94
98 '''Abstract class for wavefunctions and Amplitudes'''
99
100
101
102 ext_deps = None
103
def __init__(self, arguments, old_name, nature):
    """Record the call arguments, the original name and the kind of object.

    ``name`` starts out as None, ``dead`` as False, the usage counter
    ``nb_used`` at 0 and the dependency list ``linkdag`` empty.
    """
    # What was read from the original call.
    self.nature = nature
    self.old_name = old_name
    self.args = arguments
    # Bookkeeping state with its initial values.
    self.linkdag = []
    self.nb_used = 0
    self.dead = False
    self.name = None
112
116
119
120 @staticmethod
122 old_args = get_arguments(line)
123 old_name = old_args[-1].replace(' ','')
124 matches = graph.old_names() & set([old.replace(' ','') for old in old_args])
125 try:
126 matches.remove(old_name)
127 except KeyError:
128 pass
129 old_deps = old_args[0:len(matches)]
130
131
132 graph.kill_old(old_name)
133 return [graph.dependencies(dep) for dep in old_deps]
134
@classmethod
def good_helicity(cls, wavs, graph, diag_number=None, all_hel=None, bad_hel_amp=None):
    """Tell whether the wavefunctions ``wavs`` form a contributing combination.

    Side effect: ``cls.ext_deps`` is set to the set of external wavefunctions
    that ``graph.find_path`` can reach from any element of ``wavs``.

    Args:
        wavs: wavefunctions feeding the object under construction.
        graph: object providing ``external_wavs`` and ``find_path``.
        diag_number: diagram number; the per-helicity veto is only applied
            when this is truthy.
        all_hel: sequence of helicity tuples; the 1-based position of a
            tuple is its helicity number.  Only consulted for the veto.
        bad_hel_amp: container of ``(hel_number, diag_number)`` pairs that
            must be rejected.  ``None`` or empty means nothing is vetoed.

    Returns:
        ``cls.ext_deps`` when the combination is good (which is falsy if the
        set is empty), otherwise ``False``.

    Raises:
        ValueError: if a veto lookup is needed and the helicity tuple is not
            found in ``all_hel``.
    """
    externals = graph.external_wavs
    cls.ext_deps = {ext for dep in wavs for ext in externals
                    if graph.find_path(dep, ext)}
    deps = cls.ext_deps

    # Good when the dependencies fit inside at least one allowed combination.
    is_good = any(deps.issubset(comb) for comb in External.good_wav_combs)

    # The veto can only reject, so the helicity lookup is skipped when there
    # is nothing to veto; this also keeps ``all_hel=None`` from crashing.
    if is_good and deps and diag_number and bad_hel_amp:
        known_hel = [] if all_hel is None else all_hel
        helicity = {ext.get_id(): ext.hel for ext in deps}
        this_hel = tuple(helicity[leg] for leg in range(1, len(helicity) + 1))
        hel_number = 1 + known_hel.index(this_hel)
        if (hel_number, diag_number) in bad_hel_amp:
            is_good = False

    return is_good and deps
157
158 @staticmethod
160 old_args = get_arguments(line)
161 old_name = old_args[-1].replace(' ','')
162
163 this_args = copy(old_args)
164 wav_names = [w.name for w in wavs]
165 this_args[0:len(wavs)] = wav_names
166
167
168 return this_args
169
170 @staticmethod
173
174 @classmethod
175 - def get_obj(cls, line, wavs, graph, diag_num = None):
185
186
189
192
194 '''Class for storing external wavefunctions'''
195
196 good_hel = []
197 nhel_lines = ''
198 num_externals = 0
199
200 wavs_same_leg = {}
201 good_wav_combs = []
202
def __init__(self, arguments, old_name):
    """Build an external wavefunction from its call arguments.

    ``hel`` is the integer value of the third argument.  ``mg`` is the
    integer found after the last comma of the first argument once its final
    character has been dropped.  ``hel_ranges`` starts empty and
    ``raise_num`` is called last.
    """
    super().__init__(arguments, old_name, 'external')

    hel_field = self.args[2]
    self.hel = int(hel_field)

    first_arg = arguments[0]
    after_last_comma = first_arg.split(',')[-1]
    self.mg = int(after_last_comma[:-1])

    self.hel_ranges = []
    self.raise_num()
209
210 @classmethod
213
214 @classmethod
216
217
218 old_args = get_arguments(line)
219 old_name = old_args[-1].replace(' ','')
220 graph.kill_old(old_name)
221
222 if 'NHEL' in old_args[2].upper():
223 nhel_index = re.search(r'\(.*?\)', old_args[2]).group()
224 ext_num = int(nhel_index[1:-1]) - 1
225 new_hels = sorted(list(External.hel_ranges[ext_num]), reverse=True)
226 new_hels = [int_to_string(i) for i in new_hels]
227 else:
228
229 ext_num = int(re.search(r'\(0,(\d+)\)', old_args[0]).group(1)) -1
230 new_hels = [' 0']
231
232 new_wavfuncs = []
233 for hel in new_hels:
234
235 this_args = copy(old_args)
236 this_args[2] = hel
237
238 this_wavfunc = External(this_args, old_name)
239 this_wavfunc.set_name(len(graph.external_wavs) + len(graph.internal_wavs) +1)
240
241 graph.store_wav(this_wavfunc)
242 new_wavfuncs.append(this_wavfunc)
243 if ext_num in cls.wavs_same_leg:
244 cls.wavs_same_leg[ext_num] += new_wavfuncs
245 else:
246 cls.wavs_same_leg[ext_num] = new_wavfuncs
247
248 return new_wavfuncs
249
250 @classmethod
252 num_combs = len(cls.good_hel)
253 gwc_old = [[] for x in range(num_combs)]
254 gwc=[]
255 for n, comb in enumerate(cls.good_hel):
256 sols = [[]]
257 for leg, wavs in cls.wavs_same_leg.items():
258 valid = []
259 for wav in wavs:
260 if comb[leg] == wav.hel:
261 valid.append(wav)
262 gwc_old[n].append(wav)
263 if len(valid) == 1:
264 for sol in sols:
265 sol.append(valid[0])
266 else:
267 tmp = []
268 for w in valid:
269 for sol in sols:
270 tmp2 = list(sol)
271 tmp2.append(w)
272 tmp.append(tmp2)
273 sols = tmp
274 gwc += sols
275
276 cls.good_wav_combs = gwc
277
278 @staticmethod
281
283 """ return the id of the particle under consideration """
284
285 try:
286 return self.id
287 except:
288 self.id = int(re.findall(r'P\(0,(\d+)\)', self.args[0])[0])
289 return self.id
290
294 '''Class for storing internal wavefunctions'''
295
296 max_wav_num = 0
297 num_internals = 0
298
299 @classmethod
302
303 @classmethod
305 deps = cls.get_deps(line, graph)
306 new_wavfuncs = [ cls.get_obj(line, wavs, graph)
307 for wavs in product(*deps)
308 if cls.good_helicity(wavs, graph) ]
309
310 return new_wavfuncs
311
312
313
314 @classmethod
317
318 @classmethod
324
325 - def __init__(self, arguments, old_name):
328
329
330 @staticmethod
333
335 '''Class for storing Amplitudes'''
336
337 max_amp_num = 0
338
def __init__(self, arguments, old_name, diag_num):
    """Create an amplitude object tagged with the diagram number ``diag_num``."""
    nature = 'amplitude'
    # The diagram number is stored before delegating to the parent initialiser.
    self.diag_num = diag_num
    super().__init__(arguments, old_name, nature)
342
343
344 @staticmethod
347
@classmethod
def generate_amps(cls, line, graph, all_hel=None, all_bad_hel=[]):
    """Build one amplitude object per accepted combination of dependencies.

    The diagram number is read from the first parenthesised group of the
    last call argument of ``line``.  Every combination produced by the
    cartesian product of ``cls.get_deps(line, graph)`` is checked with
    ``cls.good_helicity``; accepted ones are turned into objects with
    ``cls.get_obj``.  Returns the list of new objects, in product order.
    """
    call_args = get_arguments(line)
    old_name = call_args[-1].replace(' ', '')

    bracketed = re.search(r'\(.*?\)', old_name).group()
    diag_num = int(bracketed[1:-1])

    new_amps = []
    for wavs in product(*cls.get_deps(line, graph)):
        # good_helicity must run first: it refreshes cls.ext_deps.
        if cls.good_helicity(wavs, graph, diag_num, all_hel, all_bad_hel):
            new_amps.append(cls.get_obj(line, wavs, graph, diag_num))
    return new_amps
363
364 @classmethod
366 return Amplitude(new_args, old_name, diag_num)
367
368 @classmethod
370 wavs, graph = args
371 amp_num = -1
372 exts = graph.external_wavs
373 hel_amp = tuple([w.hel for w in sorted(cls.ext_deps, key=lambda x: x.mg)])
374 amp_num = External.map_hel[hel_amp] +1
375
376 if cls.max_amp_num < amp_num:
377 cls.max_amp_num = amp_num
378 return amp_num
379
381 '''Class for recycling helicity'''
382
def __init__(self, good_elements, bad_amps=None, bad_amps_perhel=None):
    """Reset the class-level state and set up a fresh recycler.

    The class attributes of ``External``, ``Internal`` and ``Amplitude``
    listed below are reset, so a new instance starts from a clean slate.

    Args:
        good_elements: stored unchanged as ``self.good_elements``.
        bad_amps: stored as ``self.bad_amps``; a new empty list is created
            when omitted, so instances never share one default list.
        bad_amps_perhel: stored as ``self.bad_amps_perhel``; a new empty
            list is created when omitted.
    """
    # Class-level state shared by all wavefunction / amplitude objects.
    External.good_hel = []
    External.nhel_lines = ''
    External.num_externals = 0
    External.wavs_same_leg = {}
    External.good_wav_combs = []

    Internal.max_wav_num = 0
    Internal.num_internals = 0

    Amplitude.max_amp_num = 0

    self.last_category = None
    self.good_elements = good_elements
    self.bad_amps = [] if bad_amps is None else bad_amps
    self.bad_amps_perhel = [] if bad_amps_perhel is None else bad_amps_perhel

    # Default file names; input and output default to the same name.
    self.input_file = 'matrix_orig.f'
    self.output_file = 'matrix_orig.f'
    self.template_file = 'template_matrix.f'

    # Values substituted into the template.
    self.template_dict = {}
    self.template_dict['helicity_lines'] = '\n'
    self.template_dict['helas_calls'] = []
    self.template_dict['jamp_lines'] = '\n'
    self.template_dict['amp2_lines'] = '\n'
    self.template_dict['ncomb'] = '0'
    self.template_dict['nwavefuncs'] = '0'

    self.dag = DAG()

    self.diag_num = 1
    self.got_gwc = False

    # Upper-cased input file name without its extension.
    self.procedure_name = self.input_file.split('.')[0].upper()
    self.procedure_kind = 'FUNCTION'

    self.old_out_name = ''
    self.loop_var = 'K'

    self.all_hel = []
    self.hel_filt = True
427
436
438 self.output_file = file
439
442
444
445 if not 'CALL' in line:
446 return None
447
448
449 ext_calls = ['CALL OXXXXX', 'CALL IXXXXX', 'CALL VXXXXX', 'CALL SXXXXX']
450 if any( call in line for call in ext_calls ):
451 return 'external'
452
453
454
455
456 if not self.dag.external_wavs:
457 return None
458
459
460
461 matches = self.dag.old_names() & set(get_arguments(line))
462 try:
463 matches.remove(get_arguments(line)[-1])
464 except KeyError:
465 pass
466 try:
467 function = (line.split('(', 1)[0]).split()[-1]
468 except IndexError:
469 return None
470
471
472 if (function.split('_')[-1] != '0'):
473 return 'internal'
474 elif (function.split('_')[-1] == '0'):
475 return 'amplitude'
476 else:
477 print(f'Ahhhh what is going on here?\n{line}')
478 set_trace()
479
480 return None
481
482
483
485 old_pat = matchobj.group()
486 new_pat = old_pat.replace('AMP(', 'AMP( %s,' % self.loop_var)
487
488
489 return new_pat
490
492 '''Add loop_var index to amp and output variable.
493 Also update name of output variable.'''
494
495 new_line = re.sub(r'\WAMP\(.*?\)', self.add_amp_index, line)
496 new_line = re.sub(r'MATRIX\d+', 'TS(K)', new_line)
497 return new_line
498
500
501
502
503
504 return 'init_mode' in line.lower()
505
506
507
508
510 if f'{self.procedure_kind} {self.procedure_name}' in line:
511 if 'SUBROUTINE' == self.procedure_kind:
512 self.old_out_name = get_arguments(line)[-1]
513 if 'FUNCTION' == self.procedure_kind:
514 self.old_out_name = line.split('(')[0].split()[-1]
515
517
518 if 'diagram number' in line:
519 self.amp_calc_started = True
520
521 if ('AMP' not in get_arguments(line)[-1]
522 and self.amp_calc_started and list(line)[0] != 'C'):
523
524 if self.function_call(line) not in ['external',
525 'internal',
526 'amplitude']:
527 self.jamp_started = True
528 self.amp_calc_started = False
529 if self.jamp_started:
530 self.get_jamp_lines(line)
531 if self.in_amp2:
532 self.get_amp2_lines(line)
533 if self.find_amp2 and line.startswith(' ENDDO'):
534 self.in_amp2 = True
535 self.find_amp2 = False
536
538 if self.jamp_finished(line):
539 self.jamp_started = False
540 self.find_amp2 = True
541 elif not line.isspace():
542 self.template_dict['jamp_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'
543
545 if line.startswith(' DO I = 1, NCOLOR'):
546 self.in_amp2 = False
547 elif not line.isspace():
548 self.template_dict['amp2_lines'] += f'{line[0:6]} {self.add_indices(line[6:])}'
549
551 self.amp_calc_started = False
552 self.jamp_started = False
553 self.find_amp2 = False
554 self.in_amp2 = False
555 self.nhel_started = False
556
558
559
560
561
562 if nature not in ['external', 'internal', 'amplitude']:
563 raise Exception('wrong unfolding')
564
565 if nature == 'external':
566 new_objs = External.generate_wavfuncs(line, self.dag)
567 for obj in new_objs:
568 obj.line = apply_args(line, [obj.args])
569 else:
570 deps = Amplitude.get_deps(line, self.dag)
571 name2dep = dict([(d.name,d) for d in sum(deps,[])])
572
573
574 if nature == 'internal':
575 new_objs = Internal.generate_wavfuncs(line, self.dag)
576 for obj in new_objs:
577 obj.line = apply_args(line, [obj.args])
578 obj.linkdag = []
579 for name in obj.args:
580 if name in name2dep:
581 name2dep[name].nb_used +=1
582 obj.linkdag.append(name2dep[name])
583
584 if nature == 'amplitude':
585 nb_diag = re.findall(r'AMP\((\d+)\)', line)[0]
586 if nb_diag not in self.bad_amps:
587 new_objs = Amplitude.generate_amps(line, self.dag, self.all_hel, self.bad_amps_perhel)
588 out_line = self.apply_amps(line, new_objs)
589 for i,obj in enumerate(new_objs):
590 if i == 0:
591 obj.line = out_line
592 obj.nb_used = 1
593 else:
594 obj.line = ''
595 obj.nb_used = 1
596 obj.linkdag = []
597 for name in obj.args:
598 if name in name2dep:
599 name2dep[name].nb_used +=1
600 obj.linkdag.append(name2dep[name])
601 else:
602 return ''
603
604
605 return new_objs
606
607
614
def get_gwc(self, line, category):
    """Refresh External's wavefunction combinations after an external call.

    Categories other than 'external', 'internal' and 'amplitude' are ignored
    entirely.  Otherwise ``External.get_gwc()`` is run when the previously
    recorded category was 'external', and ``category`` then becomes the new
    ``self.last_category``.  ``line`` is not used.  Always returns None.
    """
    handled = ('external', 'internal', 'amplitude')
    if category in handled:
        if self.last_category == 'external':
            External.get_gwc()
        # Remember this call's kind for the next invocation.
        self.last_category = category
626
653
655 self.counter += 1
656 formatted_hel = [f'{hel}' if hel < 0 else f' {hel}' for hel in hel_comb]
657 nexternal = len(hel_comb)
658 return (f' DATA (NHEL(I,{self.counter}),I=1,{nexternal}) /{",".join(formatted_hel)}/')
659
661
662 with open(self.input_file, 'r') as input_file:
663
664 self.prepare_bools()
665
666 for line_num, line in tqdm(enumerate(input_file), total=get_num_lines(self.input_file)):
667 if line_num == 0:
668 line_cache = line
669 continue
670
671 if '!SKIP' in line:
672 continue
673
674 char_5 = ''
675 try:
676 char_5 = line[5]
677 except IndexError:
678 pass
679 if char_5 == '$':
680 line_cache = undo_multiline(line_cache, line)
681 continue
682
683 line, line_cache = line_cache, line
684
685 self.get_old_name(line)
686 self.get_good_hel(line)
687 self.get_amp_stuff(line_num, line)
688 call_type = self.function_call(line)
689 self.get_gwc(line, call_type)
690
691
692 if call_type in ['external', 'internal', 'amplitude']:
693 self.template_dict['helas_calls'] += self.unfold_helicities(
694 line, call_type)
695
696 self.template_dict['nwavefuncs'] = max(External.num_externals, Internal.max_wav_num)
697
698 for i in range(len(self.template_dict['helas_calls'])-1,-1,-1):
699 obj = self.template_dict['helas_calls'][i]
700 if obj.nb_used == 0:
701 obj.line = ''
702 for dep in obj.linkdag:
703 dep.nb_used -= 1
704
705
706
707 self.template_dict['helas_calls'] = '\n'.join([f'{obj.line.rstrip()} ! count {obj.nb_used}'
708 for obj in self.template_dict['helas_calls']
709 if obj.nb_used > 0 and obj.line])
710
712 out_file = open(self.output_file, 'w+')
713 with open(self.template_file, 'r') as file:
714 for line in file:
715 s = Template(line)
716 line = s.safe_substitute(self.template_dict)
717 line = '\n'.join([do_multiline(sub_lines) for sub_lines in line.split('\n')])
718 out_file.write(line)
719 out_file.close()
720
722 out_file = open(self.output_file, 'w+')
723 self.template_dict['ncomb'] = '0'
724 self.template_dict['nwavefuncs'] = '0'
725 self.template_dict['helas_calls'] = ''
726 with open(self.template_file, 'r') as file:
727 for line in file:
728 s = Template(line)
729 line = s.safe_substitute(self.template_dict)
730 line = '\n'.join([do_multiline(sub_lines) for sub_lines in line.split('\n')])
731 out_file.write(line)
732 out_file.close()
733
734
745
748
def get_arguments(line):
    '''Split the contents of the first balanced parenthesis group of 'line'.

    The text between the first '(' and its matching ')' is cut at every
    comma that is not nested inside inner parentheses.  The pieces are
    returned as a list of strings with their whitespace untouched; the
    result is [''] when 'line' contains no '('.
    '''
    pieces = ['']
    depth = 0
    for ch in line:
        if ch == '(':
            depth += 1
            if depth == 1:
                # The bracket opening the group itself is not kept.
                continue
        elif ch == ')':
            depth -= 1
            if depth == 0:
                # Matching close of the first group: stop scanning.
                break
        elif ch == ',' and depth == 1:
            # Top-level comma: start the next argument.
            pieces.append('')
            continue
        if depth > 0:
            pieces[-1] += ch
    return pieces
776
779 function = (old_line.split('(')[0]).split()[-1]
780 old_args = old_line.split(function)[-1]
781 new_lines = [old_line.replace(old_args, f'({",".join(x)})\n')
782 for x in all_the_args]
783
784 return ''.join(new_lines)
785
787 if not new_amps:
788 return ''
789 fct = line.split('(',1)[0].split('_0')[0]
790 for i,amp in enumerate(new_amps):
791 if i == 0:
792 occur = []
793 for a in amp.args:
794 if "W(1," in a:
795 tmp = collections.defaultdict(int)
796 tmp[a] += 1
797 occur.append(tmp)
798 else:
799 for i in range(len(occur)):
800 a = amp.args[i]
801 occur[i][a] +=1
802
803
804 nb_wav = [len(o) for o in occur]
805 to_remove = nb_wav.index(max(nb_wav))
806
807 occur.pop(to_remove)
808
809 lines = []
810
811 wav_name = [o.keys() for o in occur]
812 for wfcts in product(*wav_name):
813
814 sub_amps = [amp for amp in new_amps
815 if all(w in amp.args for w in wfcts)]
816 if not sub_amps:
817 continue
818 if len(sub_amps) ==1:
819 lines.append(apply_args(line, [i.args for i in sub_amps]).replace('\n',''))
820
821 continue
822
823
824 sub_amps.sort(key=lambda a: int(a.args[-1][:-1].split(',',1)[1]))
825 windices = []
826 hel_calculated = []
827 iamp = 0
828 for i,amp in enumerate(sub_amps):
829 args = amp.args[:]
830
831 wcontract = args.pop(to_remove)
832 windex = wcontract.split(',')[1].split(')')[0]
833 windices.append(windex)
834 amp_result, args[-1] = args[-1], 'TMP(1)'
835
836 if i ==0:
837
838
839 spin = fct.split(None,1)[1][to_remove]
840 lines.append('%sP1N_%s(%s)' % (fct, to_remove+1, ', '.join(args)))
841
842 hel, iamp = re.findall('AMP\((\d+),(\d+)\)', amp_result)[0]
843 hel_calculated.append(hel)
844
845
846
847
848 if spin in "VF":
849 lines.append(""" call CombineAmp(%(nb)i,
850 & (/%(hel_list)s/),
851 & (/%(w_list)s/),
852 & TMP, W, AMP(1,%(iamp)s))""" %
853 {'nb': len(sub_amps),
854 'hel_list': ','.join(hel_calculated),
855 'w_list': ','.join(windices),
856 'iamp': iamp
857 })
858 elif spin == "S":
859 lines.append(""" call CombineAmpS(%(nb)i,
860 &(/%(hel_list)s/),
861 & (/%(w_list)s/),
862 & TMP, W, AMP(1,%(iamp)s))""" %
863 {'nb': len(sub_amps),
864 'hel_list': ','.join(hel_calculated),
865 'w_list': ','.join(windices),
866 'iamp': iamp
867 })
868 else:
869 raise Exception("split amp are not supported for spin2 and 3/2")
870
871
872 return '\n'.join(lines)
873
875 name = wav.name
876 between_brackets = re.search(r'\(.*?\)', name).group()
877 num = int(between_brackets[1:-1].split(',')[-1])
878 return num
879
881 new_line = new_line[6:]
882 old_line = old_line.replace('\n','')
883 return f'{old_line}{new_line}'
884
886 char_limit = 72
887 num_splits = len(line)//char_limit
888 if num_splits != 0 and len(line) != 72 and '!' not in line:
889 split_line = [line[i*char_limit:char_limit*(i+1)] for i in range(num_splits+1)]
890 indent = ''
891 for char in line[6:]:
892 if char == ' ':
893 indent += char
894 else:
895 break
896
897 line = f'\n ${indent}'.join(split_line)
898 return line
899
def int_to_string(i):
    """Format the helicity ``i`` (-1, 0 or 1) as a two-character string.

    Returns '+1', ' 0' or '-1'.

    Raises:
        ValueError: if ``i`` is not one of the three supported values.
    """
    formatted = {1: '+1', 0: ' 0', -1: '-1'}
    try:
        return formatted[i]
    except (KeyError, TypeError):
        # Raise instead of dropping into a debugger and exiting the process.
        raise ValueError(f'How can {i} be a helicity?') from None
911
def main():
    """Command-line entry point.

    Parses the positional ``input_file`` and ``hel_file`` arguments and the
    ``--hf-off`` / ``--as-off`` switches, reads the first line of
    ``hel_file`` as whitespace-separated elements, configures a
    ``HelicityRecycler`` and calls its ``generate_output_file``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('input_file', help='The file containing the '
                     'original matrix calculation')
    cli.add_argument('hel_file', help='The file containing the '
                     'contributing helicities')
    # Both switches default to True and are turned off by their flag.
    for flag, dest, text in (
            ('--hf-off', 'hel_filt', 'Disable helicity filtering'),
            ('--as-off', 'amp_splt', 'Disable amplitude splitting')):
        cli.add_argument(flag, dest=dest, action='store_false', default=True, help=text)

    options = cli.parse_args()

    with open(options.hel_file, 'r') as hel_handle:
        first_line = hel_handle.readline()
    recycler = HelicityRecycler(first_line.split())

    recycler.hel_filt = options.hel_filt
    recycler.amp_splt = options.amp_splt

    recycler.set_input(options.input_file)
    recycler.set_output('green_matrix.f')
    recycler.set_template('template_matrix1.f')

    recycler.generate_output_file()
936
937 if __name__ == '__main__':
938 main()
939