import onetick.py as otp
from onetick.py import configuration
import onetick.query as otq
class MultiOutputSource:
    """
    Construct a multi-output source object from several connected otp.Source objects.

    This object can be saved to disk as a graph using to_otq() method, or passed to otp.run() function.
    If it's passed to otp.run(), then returned results for different outputs will be available as a dictionary.

    Parameters
    ----------
    outputs : dict
        Dictionary which keys are names of the output sources, and values are output sources themselves.
        All the passed sources should be connected.
    main_branch_name : str, optional
        Name of the output (one of the keys of ``outputs``) that is used as the main branch;
        all the other outputs are applied to it as side branches.
        By default the first output of ``outputs`` is used.

    Raises
    ------
    ValueError
        If fewer than two outputs are passed, if the outputs are not connected,
        or if ``main_branch_name`` is not found among the keys of ``outputs``.

    Examples
    --------

    Results for individual outputs can be accessed by output names

    >>> # OTdirective: skip-snippet:;
    >>> root = otp.Tick(A=1)
    >>> branch_1 = root.copy()
    >>> branch_2 = root.copy()
    >>> branch_3 = root.copy()
    >>> branch_1['B'] = 1
    >>> branch_2['B'] = 2
    >>> branch_3['B'] = 3
    >>> src = otp.MultiOutputSource(dict(BRANCH1=branch_1, BRANCH2=branch_2, BRANCH3=branch_3))
    >>> res = otp.run(src)
    >>> sorted(list(res.keys()))
    ['BRANCH1', 'BRANCH2', 'BRANCH3']
    >>> # OTdirective: skip-snippet:;
    >>> res['BRANCH1'][['A', 'B']]
       A  B
    0  1  1
    >>> # OTdirective: skip-snippet:;
    >>> res['BRANCH2'][['A', 'B']]
       A  B
    0  1  2
    >>> # OTdirective: skip-snippet:;
    >>> res['BRANCH3'][['A', 'B']]
       A  B
    0  1  3

    node_name parameter of the otp.run() method can be used to select outputs

    >>> # OTdirective: skip-snippet:;
    >>> src = otp.MultiOutputSource(dict(BRANCH1=branch_1, BRANCH2=branch_2, BRANCH3=branch_3))
    >>> res = otp.run(src, node_name=['BRANCH2', 'BRANCH3'])
    >>> sorted(list(res.keys()))
    ['BRANCH2', 'BRANCH3']
    >>> # OTdirective: skip-snippet:;
    >>> res['BRANCH2'][['A', 'B']]
       A  B
    0  1  2
    >>> # OTdirective: skip-snippet:;
    >>> res['BRANCH3'][['A', 'B']]
       A  B
    0  1  3

    If only one output is selected, then it's returned directly and not in a dictionary

    >>> # OTdirective: skip-snippet:;
    >>> src = otp.MultiOutputSource(dict(BRANCH1=branch_1, BRANCH2=branch_2, BRANCH3=branch_3))
    >>> res = otp.run(src, node_name='BRANCH2')
    >>> res[['A', 'B']]
       A  B
    0  1  2

    A dictionary with sources can also be passed to otp.run directly,
    and MultiOutputSource object will be constructed internally

    >>> res = otp.run(dict(BRANCH1=branch_1, BRANCH2=branch_2))
    >>> # OTdirective: skip-snippet:;
    >>> res['BRANCH1'][['A', 'B']]
       A  B
    0  1  1
    >>> # OTdirective: skip-snippet:;
    >>> res['BRANCH2'][['A', 'B']]
       A  B
    0  1  2
    """

    def __init__(self, outputs, main_branch_name=None):
        if len(outputs) <= 1:
            raise ValueError('At least two branches should be passed to a MultiOutputSource object')

        # 1. Checking that outputs have a common part
        self._check_outputs_connected(outputs.values())

        # 2, 3. Assigning node names and selecting main branch
        self.__main_branch_name = None
        self.__main_branch = None
        self.__side_branches = {}
        for node_name, source in outputs.items():
            # work on a copy so that the sources passed by the caller are not modified
            source = source.copy()
            # this is necessary to create different branches if a source is a branching point
            source.sink(otq.Passthrough())
            source.node().node_name(node_name)
            # the main branch is the first output, unless a specific name was requested
            is_main_candidate = main_branch_name is None or main_branch_name == node_name
            if self.__main_branch_name is None and is_main_candidate:
                self.__main_branch_name = node_name
                self.__main_branch = source
            else:
                self.__side_branches[node_name] = source
        if self.__main_branch_name is None:
            raise ValueError(f'Branch name "{main_branch_name}" not found among passed outputs!')

        # 4, 5. Apply other branches to the main branch and copy dicts
        self.__main_branch._apply_side_branches(self.__side_branches.values())

    @staticmethod
    def _check_outputs_connected(sources):
        """
        Raise ValueError unless all the passed sources are connected.

        For every source a set of keys is collected from the history rules of its node.
        Two sources are connected if their key sets have any key in common,
        either directly or through a chain of other sources.
        """
        key_sets = [
            {rule.key for rule in source.node()._hist._rules if "key" in rule.key_params}
            for source in sources
        ]
        if not key_sets:
            return
        # grow the component of the first source until it has absorbed every other key set
        connected_keys, *pending = key_sets
        while pending:
            still_pending = []
            for key_set in pending:
                if connected_keys.isdisjoint(key_set):
                    # no common keys (yet); retry after the component has grown
                    still_pending.append(key_set)
                else:
                    connected_keys |= key_set
            # a full pass without a single merge means the rest can never be reached
            if len(still_pending) == len(pending):
                raise ValueError("Cannot construct a MultiOutputSource object from outputs that are not connected!")
            pending = still_pending

    def _all_node_names(self):
        """Return the names of all outputs: the main branch first, then the side branches."""
        return [self.__main_branch_name, *self.__side_branches]

    def _side_branch_list(self):
        """Return the side branch sources (all outputs except the main branch) as a list."""
        return list(self.__side_branches.values())

    def _prepare_for_execution(self, symbols=None, start=None, end=None, start_time_expression=None,
                               end_time_expression=None, timezone=None, has_output=None,
                               running_query_flag=None, require_dict=False, node_name=None):
        """
        Delegate preparation for execution to the main branch.

        The passed ``has_output`` value is ignored: False is always forwarded.
        If ``timezone`` is None, the configured default timezone is used.
        If ``node_name`` is None, the names of all outputs are used.
        """
        if timezone is None:
            timezone = configuration.config.tz
        if node_name is None:  # if user passed a node name, we shouldn't overwrite it
            node_name = self._all_node_names()
        return self.__main_branch._prepare_for_execution(
            symbols=symbols,
            start=start,
            end=end,
            start_time_expression=start_time_expression,
            end_time_expression=end_time_expression,
            timezone=timezone,
            has_output=False,  # to avoid sinking PASSTHROUGH to the main branch
            running_query_flag=running_query_flag,
            require_dict=require_dict,
            node_name=node_name,
        )

    def to_otq(self, file_name=None, file_suffix=None, query_name=None, symbols=None, start=None, end=None,
               timezone=None):
        """
        Constructs an otq graph and saves it to disk

        All parameters are forwarded to the ``to_otq`` method of the main branch;
        if ``timezone`` is None, the configured default timezone is used.

        See Also
        --------
        otp.Source.to_otq
        """
        if timezone is None:
            timezone = configuration.config.tz
        return self.__main_branch.to_otq(file_name=file_name,
                                         file_suffix=file_suffix,
                                         query_name=query_name,
                                         symbols=symbols,
                                         start=start,
                                         end=end,
                                         timezone=timezone,
                                         add_passthrough=False)