Note: We no longer publish the latest version of our code here; we primarily use the kumc-bmi GitHub organization. The HERON ETL repository, in particular, is not public. Peers in the informatics community should see MultiSiteDev for details on requesting access.

source: heron_load/devdoc/datadeps.py @ 0:42ad7288920a

heron-michigan tip
Last change on this file: 0:42ad7288920a, checked in by Matt Hoag <mhoag@…>, 6 years ago

Merge with demo_concepts_3800

File size: 18.3 KB
'''datadeps -- trace data dependencies in HERON ETL

Usage:
  $ python datadeps.py heron-etl.log --codelog >deps.csv
  $ python datadeps.py deps.csv --convert >deps.dot

.. note:: TODO: connect objects to instances (specify instance by CLI arg)

:copyright: Copyright 2010-2014 University of Kansas Medical Center
            part of the `HERON open source codebase`__;
            see NOTICE file for license details.

__ http://informatics.kumc.edu/work/wiki/HERON
'''

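# Pipeline sketch (per the usage above): --codelog turns an ETL log into a
# CSV of dependencies, and --convert turns that CSV into Graphviz dot
# syntax; rendering the .dot output (e.g. `dot -Tsvg deps.dot -o deps.svg`,
# assuming Graphviz is installed) happens outside this module.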
from StringIO import StringIO
from collections import namedtuple
from colorsys import hsv_to_rgb
from csv import DictReader
from pprint import pformat
import ast
import logging

from pyparsing import ParseException, lineno, col

from etl_log_outline import each_entry
from select_parser import AST
import select_parser as sql


log = logging.getLogger(__name__)


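# main() dispatches on CLI flags: --codelog reads SQL captured in a HERON
# ETL log and emits one dependency per CSV row; --convert turns such a CSV
# into Graphviz dot syntax; --csv scans .sql and .py sources named on the
# command line directly.  All I/O (argv, file access, CSV and line output,
# logging config) is passed in as parameters rather than reached for
# globally, which keeps the module import-safe and easy to test.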
def main(argv, open_argv, basicConfig, write_lines, csv_writer,
         level=logging.INFO):
    basicConfig(level=level)

    def by_ext(ext):
        return [(fn, open_argv(fn).read())
                for fn in argv[1:]
                if fn.endswith(ext)]

    if '--csv' in argv:
        sqldeps = SQLScriptDeps(by_ext('.sql'))
        taskdeps = ETLModuleScanner(by_ext('.py'))

        schema = sorted(set(sqldeps.columns()) | set(taskdeps.columns()))
        csv = csv_writer(schema)
        csv.writerow(dict(zip(schema, schema)))
        csv.writerows(sqldeps.rows())
        csv.writerows(taskdeps.rows())
    elif '--convert' in argv:
        src = open_argv(argv[1]).read()
        rows = list(DictReader(StringIO(src)))
        write_lines(SQLReportDeps.as_digraph(rows))
    elif '--codelog' in argv:
        log_fn = argv[1]
        sqldeps = SQLReportDeps(lambda: open_argv(log_fn))
        schema = sqldeps.columns()
        csv = csv_writer(schema)
        csv.writerow(dict(zip(schema, schema)))
        for row in sqldeps.rows():
            csv.writerow(row)
    else:
        sqldeps = SQLScriptDeps(by_ext('.sql'))
        taskdeps = ETLModuleScanner(by_ext('.py'))
        raise NotImplementedError('use --codelog or --convert')
        # write_lines(Dot.digraph('deps',
        #             sqldeps.arcs() + taskdeps.arcs()))


Dependency = namedtuple('Dependency', ('src', 'loc', 'dest'))


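# SQLDeps is the shared analysis core.  statement_deps() inspects one parsed
# statement and, for INSERT/UPDATE/DELETE and CREATE ... AS SELECT, yields
# Dependency(src, loc, dest) triples: src and dest are table references
# (dicts with table, optional schema, and create type) and loc is later
# refined by the subclasses into file/line/column detail; row() flattens a
# triple into the CSV columns written by main().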
class SQLDeps(object):
    @classmethod
    def statement_deps(cls, statement):
        name = statement.getName()
        table_ref = cls.table_ref
        if name in (AST.select_statement, AST.explain_statement,
                    AST.drop, AST.truncate, AST.alter_table, AST.create_index,
                    AST.create_sequence, AST.commit,
                    AST.alter_session, AST.whenever):
            return
        elif name in (AST.create_statement,
                      AST.insert, AST.update, AST.delete):
            dest = statement.to_object
            for src in cls.find_sources(statement.as_q):
                yield Dependency(
                    table_ref(src), src.loc,
                    table_ref(dest, statement.table or statement.view))
        else:
            log.error('not supported: %s', statement)
            # import pdb; pdb.set_trace()
            raise NotImplementedError

    @classmethod
    def table_ref(cls, tok, create=None):
        return (dict(schema=tok.database, table=tok.table, create=create)
                if tok.database
                else dict(table=tok.table, create=create))

    @classmethod
    def find_sources(cls, toks, aliases=None):
        '''
        >>> def t(s):
        ...     tree = sql.script.parseString(s)
        ...     return [src.table for src in SQLDeps.find_sources(tree)]
        >>> t('select x from y')
        ['Y']

        >>> t('with x as (select c from y) select * from x')
        ['Y']

        >>> t('select c from (select c from y) x')
        ['Y']

        >>> t("""select x.c from x join y on x.k = y.k""")
        ['X', 'Y']

        '''
        if aliases is None:
            aliases = set()

        if isinstance(toks, type('')):
            return
        log.debug('find_sources toks parts: %s\n%s',
                  toks.keys(), pformat(toks.items()))
        for alias_key in (AST.t_alias, AST.query_name):
            if alias_key in toks:
                aliases.add(toks[alias_key])
        if AST.from_object in toks:
            it = toks[AST.from_object]
            log.debug('candidate: %s', it)
            if it.database or it.table not in aliases:
                log.debug('BINGO: %s', it)
                yield it
        for tok in toks:
            for found in cls.find_sources(tok, aliases):
                yield found

    @classmethod
    def columns(cls):
        class Any(object):
            def __getitem__(self, n):
                return None

            def get(self, n, default):
                return None
        return cls.row(Any(), Any(), Any()).keys()

    @classmethod
    def row(cls, src, loc, dest):
        return dict(src_table=src['table'],
                    src_schema=src.get('schema', None),
                    code_filename=loc['file'],
                    code_line=loc['line'],
                    code_column=loc['column'],
                    dest_table=dest['table'],
                    dest_schema=dest.get('schema', None),
                    dest_type=dest.get('create', None))


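# SQLScriptDeps analyzes whole .sql files (the --csv path): it parses each
# script and resolves every dependency's location to file, line, and column
# within that script.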
class SQLScriptDeps(SQLDeps):
    def __init__(self, sources):
        self._sources = sources

    @classmethod
    def script_deps(cls, script_name, script):
        '''Find data dependencies in an SQL script.

        :return: generator of (source, occurrence, target)
                 where source and target are SQL objects (tables, views, ...)
                 and occurrence gives a file and line number

        >>> def t(ex):
        ...     print pformat(list(SQLScriptDeps.script_deps('ex.sql', ex)))

        >>> t('create view v1 as select t.c from t')
        ... # doctest: +NORMALIZE_WHITESPACE
        [Dependency(src={'table': 'T', 'create': None},
                    loc={'column': 35, 'line': 1, 'file': 'ex.sql'},
                    dest={'table': 'V1', 'create': 'VIEW'})]

        >>> t('create view s1.v1 as select t.c from s2.t')
        ... # doctest: +NORMALIZE_WHITESPACE
        [Dependency(src={'table': 'T', 'create': None, 'schema': 'S2'},
                    loc={'column': 38, 'line': 1, 'file': 'ex.sql'},
                    dest={'table': 'V1', 'create': 'VIEW', 'schema': 'S1'})]

        '''
        try:
            logging.info('parsing script: %s', script_name)
            parts = sql.script.parseString(script)
        except ParseException, pe:
            log.error('Parse error at\n%s:%s:\n%s\n%s',
                      script_name, pe.lineno, pe.msg, pe.markInputline())
            log.debug('call stack:', exc_info=pe)
            return []

        logging.info('analyzing script: %s', script_name)
        return (d._replace(loc=dict(file=script_name,
                                    line=lineno(d.loc, script),
                                    column=col(d.loc, script)))
                for statement in parts
                for d in cls.statement_deps(statement))

    def arcs(self):
        def obj(x):
            return ('%s.%s' % (x['schema'], x['table']) if 'schema' in x
                    else x['table'])

        def srcloc(loc):
            return "%s:%d:%d" % (loc['file'], loc['line'], loc['column'])

        return ((obj(src), obj(dest), srcloc(loc))
                for fn, text in self._sources
                for src, loc, dest in self.script_deps(fn, text))

    def rows(self):
        return (self.row(src, loc, dest)
                for fn, text in self._sources
                for src, loc, dest in self.script_deps(fn, text))


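# SQLReportDeps works from the ETL log rather than the scripts themselves
# (the --codelog path): each_entry() yields structured log entries, and
# entries tagged 'code' carry the executed SQL along with the originating
# script name and line number.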
class SQLReportDeps(SQLDeps):
    def __init__(self, open_log):
        self._open_log = open_log

    @classmethod
    def log_statements(_cls, logfp):
        logging.info('parsing code from log entries...')
        for entry in each_entry(logfp):
            if not entry.detail[:1] == ['code']:
                continue
            detail = dict(zip(entry.detail[::2], entry.detail[1::2]))
            # code,...,line,NN,result,...,rowcount,...,script,SSS
            code, script, line_ = [detail[k]
                                   for k in ['code', 'script', 'line']]
            line = int(line_)

            try:
                for statement in sql.script.parseString(code):
                    yield script, line, code, statement
            except ParseException, pe:
                log.error('Parse error at\n%s:%s:\n%s\n%s',
                          script,
                          pe.lineno + line - 1, pe.msg, pe.markInputline())
                log.debug('call stack:', exc_info=pe)

    def rows(self):
        statements = self.log_statements(self._open_log())

        deps = (d._replace(loc=dict(file=script,
                                    line=ln,
                                    column=col(d.loc, text)))
                for script, ln, text, statement in statements
                for d in self.statement_deps(statement))

        return (self.row(src, loc, dest)
                for src, loc, dest in deps)

    @classmethod
    def as_digraph(cls, rows,
                   name='deps'):
        nodeattrs = cls.rows_to_nodeattrs(rows)
        arcdata = cls.rows_to_arcs(rows)
        clusters = cls.rows_to_clusters(rows)
        return Dot.digraph(name, arcdata, clusters, nodeattrs)

    @classmethod
    def rows_to_nodeattrs(cls, rows):
        views = sorted(set(
            [row['dest_table']
             for row in rows
             if 'dest_table' in row
             and row.get('dest_type', None) == 'VIEW']))
        return [(view, [('style', 'filled')]) for view in views]

    @classmethod
    def rows_to_arcs(cls, rows, noisy_src=('DUAL',),
                     implicit=('might_call',)):
        '''Extract source, dest, label strings from CSV dep data.

        >>> _rows = lambda s: list(DictReader(StringIO(s)))

        >>> arcs = SQLReportDeps.rows_to_arcs(_rows(T1.strip()))
        >>> arcs
        [('s1', [('label', 'script1:10'), ('color', '#3f527f')], 'd1')]

        >>> arcs = SQLReportDeps.rows_to_arcs(_rows(T2.strip()))
        >>> arcs
        [('ts1', [('label', 'mod1'), ('color', '#3f527f')], 'td1')]

        '''
        def obj(x, pfx):
            suffix = 'table' if x.get('src_table') else 'task'
            k_obj = pfx + suffix
            k_schema = pfx + 'schema'
            return (('%s.%s' % (x[k_schema], x[k_obj]))
                    if x.get(k_schema, None)
                    else x[k_obj])

        color_for = ColorPicker()

        def attrs(loc):
            label = [('label', ("%s:%d" % (loc['code_filename'],
                                           int(loc['code_line'])))
                      if loc.get('code_line', None)
                      else loc['code_filename'])]
            color = ([('color', color_for(loc['code_filename']))]
                     if loc.get('code_filename', None)
                     else [])
            style = ([('style', 'dotted')]
                     if loc.get('relationship') in implicit
                     else [])
            return label + color + style

        return [(obj(row, 'src_'), attrs(row), obj(row, 'dest_'))
                for row in rows
                if obj(row, 'src_') not in noisy_src]

    @classmethod
    def rows_to_clusters(cls, rows):
        '''
        >>> _rows = lambda s: list(DictReader(StringIO(s)))

        >>> c = SQLReportDeps.rows_to_clusters(_rows(T3.strip()))
        >>> c  # doctest: +NORMALIZE_WHITESPACE
        [('c0', 'schemaA',
         ['schemaA.s1']),
         ('c1', 'schemaB',
         ['schemaB.d1'])]

        '''
        schema_of_table = [(row[pfx + 'schema'],
                            '%s.%s' % (row[pfx + 'schema'],
                                       row[pfx + 'table']))
                           for pfx in ('src_', 'dest_')
                           for row in rows
                           if (pfx + 'schema') in row
                           and (pfx + 'table') in row]
        schemas = set([s for (s, t) in schema_of_table if s])
        table_lists = [(s_i, [t for (s, t) in schema_of_table
                              if s == s_i])
                       for s_i in schemas]
        return [('c%d' % ix, s, tables)
                for (ix, (s, tables)) in enumerate(table_lists)]


class ColorPicker(object):
    '''
    >>> cp = ColorPicker()
    >>> c1 = cp('s1')
    >>> c2 = cp('s2')
    >>> c1 != c2
    True
    >>> c1b = cp('s1')
    >>> c1b == c1
    True
    '''
    # ack: http://martin.ankerl.com/2009/12/09/how-to-create-random-colors-programmatically/  # noqa
    # ack: http://stackoverflow.com/a/876872

    # golden ratio conjugate
    grc = 0.618033988749895

    def __init__(self):
        self._h = 0
        self._seen = {}

    def __call__(self, s):
        return self._seen.setdefault(s, self.next_color())

    def next_color(self):
        self._h = (self._h + self.grc) % 1.0
        rgb = tuple([int(x * 255) for x in hsv_to_rgb(self._h, 0.5, 0.5)])
        return '#%02x%02x%02x' % rgb


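# T1-T3 are CSV fixtures for the doctests above: T1 exercises table-to-table
# arcs, T2 task-to-task arcs as produced by the module scanner, and T3
# schema-qualified rows for clustering.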
T1 = '''
src_table,dest_table,code_filename,code_line,code_column
s1,d1,script1,10,20
'''

T2 = '''
src_task,dest_task,code_filename,relationship
ts1,td1,mod1,rel1
'''


T3 = '''
src_table,dest_table,code_filename,code_line,code_column,src_schema,dest_schema
s1,d1,script1,10,20,schemaA,schemaB
'''


class Dot(object):
    '''
    Warning: We assume node, arc labels are quote-safe.

    >>> for line in Dot.digraph('G1', [('src1', [('a1', '1')], 'dest1')]):
    ...     print line
    digraph G1 {
      node [shape=tab];
      graph [rankdir="LR"];
      "src1" -> "dest1" [a1="1"];
    }

    >>> Dot.attrs([('label', "x"), ('style', 'dotted')])
    ' [label="x", style="dotted"]'
    >>> Dot.attrs()
    ''

    '''

    @classmethod
    def digraph(cls, name, arcdata, clusters=[], nodeattrs=[],
                rankdir='LR', node_shape='tab'):
        '''Render arcs in graphviz dot syntax lines.
        '''
        return (['digraph %s {' % name,
                 '  node [shape=%s];' % node_shape,
                 '  graph%s;' % cls.attrs(rankdir=rankdir)] +
                ['"%s" %s;' % (n, cls.attrs(atts))
                 for (n, atts) in nodeattrs] +
                [line
                 for (cname, label, nodedata) in clusters
                 for line in cls.cluster(cname, label, nodedata)] +
                cls.arcs(arcdata) +
                ['}'])

    @classmethod
    def cluster(cls, name, label, nodes):
        return (['subgraph cluster_%(name)s {' % dict(name=name),
                 '  label = "%s"' % label,
                 '  color=blue'] +
                cls.nodes(nodes) +
                ['}'])

    @classmethod
    def nodes(cls, nodes):
        nodelines = ['"%s";' % n for n in nodes]
        return sorted(set(nodelines))

    @classmethod
    def arcs(cls, data):
        arclines = ['  "%s" -> "%s"%s;' % (src, dest, cls.attrs(att))
                    for src, att, dest in data]
        return sorted(set(arclines))

    @classmethod
    def attrs(cls, attlist=[], **attkw):
        inside = ', '.join(['%s="%s"' % (k, v)
                            for (k, v) in (attlist + attkw.items())])
        return ' [' + inside + ']' if inside else ''


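# ETLModuleScanner walks the Python AST of ETL task modules, recording a
# (source task, this task, relationship) triple for each string argument of
# a @needs(...) or @might_call(...) decorator on a function that also
# carries a @task decorator.  A sketch of the shape it looks for (the task
# and script names here are hypothetical, not from this codebase):
#
#     @task
#     @needs('load_visits')
#     def load_diagnoses(script='load_diagnoses.sql'):
#         ...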
class ETLModuleScanner(ast.NodeVisitor):
    noisy = ['dblink_id_epic', 'dblink_id_kumc', 'dblink_id_deid',
             'dblink_deid_epic', 'dblink_id_idx',
             'make_epic_views',
             'deid_facts', 'deid_dimensions', 'deid_all']

    def __init__(self, sources):
        ast.NodeVisitor.__init__(self)
        self._deps = []  # mutable state. ew.
        self._sources = sources

    def deps(self, code, fn):
        self._deps = []
        logging.info('parsing module: %s', fn)
        tree = ast.parse(code, fn)
        self.visit(tree)
        return self._deps

    def rows(self):
        return [self.row(src, fn, dest, relationship)
                for fn, code in self._sources
                for src, dest, relationship in self.deps(code, fn)
                if src not in self.noisy
                and dest not in self.noisy]

    @classmethod
    def columns(cls):
        class Any(object):
            def __getattr__(self, n):
                return None
        return cls.row(Any(), Any(), Any(), 'needs').keys()

    @classmethod
    def row(cls, src, loc, dest, relationship):
        return dict(src_task=src,
                    code_filename=loc,
                    dest_task=dest,
                    relationship=relationship)

    @classmethod
    def is_task(cls, fundef):
        decorators = fundef.decorator_list
        if not decorators:
            return False
        return any(isinstance(d, ast.Name) and d.id == 'task'
                   for d in decorators)

    @classmethod
    def needs(cls, task):
        decorators = task.decorator_list
        if not decorators:
            return ()
        return [(arg.s, d.func.id)  # assume @needs() args are string literals
                for d in [d for d in decorators
                          if isinstance(d, ast.Call)
                          and isinstance(d.func, ast.Name)
                          and d.func.id in ('needs', 'might_call')]
                for arg in d.args]

    @classmethod
    def script_deps(cls, task):
        return [elt.s
                for expr in task.args.defaults
                for elt in (expr.elts if isinstance(expr, ast.Tuple)
                            else [expr] if isinstance(expr, ast.Str)
                            else [])
                if isinstance(elt, ast.Str) and (
                elt.s.endswith('.sql')
                or elt.s.endswith('.csv'))]

    def visit_FunctionDef(self, node):
        dest = node.name
        if self.is_task(node):
            log.debug('script deps of %s: %s', dest, self.script_deps(node))
            for src, relationship in self.needs(node):
                self._deps.append((src, dest, relationship))


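# The __main__ block gathers the ambient capabilities (sys.argv, stdout,
# open) and hands them to main() explicitly; _bigstack() raises the
# recursion limit first (deep recursion is easy to hit when parsing large
# scripts).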
if __name__ == '__main__':
    def _bigstack():
        import sys
        sys.setrecursionlimit(10000)

    def _with_caps():
        from sys import argv, stdout
        from csv import DictWriter

        def open_argv(path):
            ix = argv.index(path)
            return open(argv[ix])

        def write_lines(lines):
            for line in lines:
                stdout.write(line + '\n')

        def csv_writer(fieldnames):
            return DictWriter(stdout, fieldnames)

        main(argv=argv,
             open_argv=open_argv,
             write_lines=write_lines,
             csv_writer=csv_writer,
             basicConfig=logging.basicConfig)

    _bigstack()
    _with_caps()