import inspect
import datetime
import warnings
from typing import Union, List, Optional, Dict, Any, Callable
from collections import defaultdict
import onetick.query as otq
import pandas as pd
import pyomd
from onetick import py as otp
from onetick.py import utils, configuration
from onetick.py.core.column_operations.base import _Operation
from onetick.py.types import time2nsectime
from onetick.py.core.source import _is_dict_required
from onetick.lib.instance import OneTickLib
from onetick.py.compatibility import has_max_expected_ticks_per_symbol, has_password_param
from onetick.py._stack_info import _add_stack_info_to_exception
def run(query: Union[Callable, Dict, otp.Source, otp.MultiOutputSource, otp.query, str, otq.EpBase, otq.GraphQuery,
                     otq.ChainQuery, otq.Chainlet],
        *,
        symbols: Union[List[Union[str, otq.Symbol]], otp.Source, pd.DataFrame, str, None] = None,
        start: Union[datetime.datetime, otp.datetime, pyomd.timeval_t, None] = utils.adaptive,
        end: Union[datetime.datetime, otp.datetime, pyomd.timeval_t, None] = utils.adaptive,
        date: Union[datetime.date, otp.date, None] = None,
        start_time_expression: Optional[str] = None,
        end_time_expression: Optional[str] = None,
        timezone=utils.default,  # type: ignore
        context=utils.default,  # type: ignore
        username: Optional[str] = None,
        alternative_username: Optional[str] = None,
        password: Optional[str] = None,
        batch_size: Optional[int] = utils.default,
        running: Optional[bool] = False,
        query_properties: Optional[pyomd.QueryProperties] = None,
        concurrency: Optional[int] = utils.default,
        apply_times_daily: Optional[int] = None,
        symbol_date: Union[datetime.datetime, int, None] = None,
        query_params: Optional[Dict[str, Any]] = None,
        time_as_nsec: bool = True,
        treat_byte_arrays_as_strings: bool = True,
        output_matrix_per_field: bool = False,
        output_structure: Optional[str] = None,
        return_utc_times: Optional[bool] = None,
        connection=None,
        callback=None,
        svg_path=None,
        use_connection_pool: bool = False,
        node_name: Union[str, List[str], None] = None,
        require_dict: bool = False,
        max_expected_ticks_per_symbol: Optional[int] = None):
    """
    Executes a query and returns its result.

    Parameters
    ----------
    query: :py:class:`onetick.py.Source`, otq.Ep, otq.Graph, otq.GraphQuery, otq.ChainQuery, str, otq.Chainlet
        Query to execute can be source, path of the query on a disk or onetick.query graph or event processor.
        For running OTQ files, it represents the path (including filename) to the OTQ file to run a single query within
        the file. If more than one query is present, then the query to be run must be specified
        (that is, ``'path_to_file/otq_file.otq::query_to_run'``).
        A function or method is also accepted: it is called with the symbol parameter built from ``symbols``
        and its return value is used as the query.
        A dictionary is treated as a mapping of sources for :py:class:`onetick.py.MultiOutputSource`.
    symbols: str, list of str, list of otq.Symbol, :py:class:`onetick.py.Source`, pd.DataFrame, optional
        Symbol(s) to run the query for passed as a string, a list of strings, a pd.DataFrame with the ``SYMBOL_NAME``
        column, or as a "symbols" query which results include the ``SYMBOL_NAME`` column. The start/end times for the
        symbols query will be taken from the params below. See :ref:`symbols <Symbols>` for more details.
    start: datetime.datetime, :py:class:`onetick.py.datetime`, :py:class:`pyomd.timeval_t`, optional
        The start time of the query. If datetime.datetime was passed then timezone of object is ignored by Onetick,
        therefore we suggest using only :py:class:`otp.datetime <onetick.py.datetime>` objects as an argument.
        onetick.py uses otp.config.default_start_time as default value,
        if you don't want to specify start time, e.g. to use saved time of the query,
        then you should specify None value.
        See also ``timezone`` argument.
    end: datetime.datetime, :py:class:`onetick.py.datetime`, :py:class:`pyomd.timeval_t`, optional
        The end time of the query. If datetime.datetime was passed then timezone of object is ignored by Onetick,
        therefore we suggest using only :py:class:`otp.datetime <onetick.py.datetime>` objects as an argument.
        onetick.py uses otp.config.default_end_time as default value,
        if you don't want to specify end time, e.g. to use saved time of the query,
        then you should specify None value.
        See also ``timezone`` argument.
    date: datetime.date, :py:class:`onetick.py.date`, optional
        The date to run the query for. Can be set instead of ``start`` and ``end`` parameters.
        If set then the interval to run the query will be from 0:00 to 24:00 of the specified date.
    start_time_expression: str, optional
        Start time onetick expression of the query. If specified, it will take precedence over ``start``.
        Supported only if query is Source, Graph or Event Processor.
    end_time_expression: str, optional
        End time onetick expression of the query. If specified, it will take precedence over ``end``.
        Supported only if query is Source, Graph or Event Processor.
    timezone: str, optional
        The timezone of start and end times, as well as of the output timestamps. It has higher priority than timezone
        of start and end parameters. If parameter is omitted timestamps of ticks will be formatted with
        the default timezone.
    context: str (defaults to otp.config.default_context), optional
        Allows specification of different instances of OneTick tick_servers to connect to
    username: str, optional
        The username to make the connection.
        By default the user which executed the process is used.
    alternative_username: str, optional
        The username used for authentication.
        Needs to be set only when the tick server is configured to use password-based authentication.
        By default, ``otp.config.default_auth_username`` is used.
    password: str, optional
        The password used for authentication.
        Needs to be set only when the tick server is configured to use password-based authentication.
        Note: not supported and ignored on older OneTick versions.
        By default, ``otp.config.default_password`` is used.
    batch_size: int
        number of symbols to run in one batch.
        By default, the value from otp.config.default_batch_size is used.
    running: bool, optional
        Indicates whether a query is CEP or not. Default is `False`.
    query_properties: :py:class:`pyomd.QueryProperties`, optional
        Query properties, such as ONE_TO_MANY_POLICY, ALLOW_GRAPH_REUSE, etc
        If the ``USE_FT`` property is not set, it is set to ``otp.config.default_fault_tolerance``.
    concurrency: int, optional
        The maximum number of CPU cores to use to process the query.
        By default, the value from otp.config.default_concurrency is used.
    apply_times_daily: bool
        Runs the query for every day in the ``start``-``end`` time range,
        using the time components of ``start`` and ``end`` datetimes.
        Note that those daily intervals are executed separately, so you don't have access
        to the data from previous or next days (see example in the next section).
    symbol_date:
        The symbol date used to look up symbology mapping information in the reference database,
        expressed as datetime object or integer of YYYYMMDD format
    query_params: dict
        Parameters of the query.
    time_as_nsec: bool
        Outputs timestamps up to nanoseconds granularity
        (defaults to True; when False, timestamps are output in microseconds granularity)
    treat_byte_arrays_as_strings: bool
        Outputs byte arrays as strings (defaults to True)
    output_matrix_per_field: bool
        Changes output format to list of matrices per field.
    output_structure: str, optional
        Structure (type) of the result. Supported values are:

          - `df` (default) - the result is returned as pandas.DataFrame
            or dict[symbol: pandas.Dataframe] in case of using multiple symbols or first stage query.
          - `map` - the result is returned as SymbolNumpyResultMap.
          - `list` - the result is returned as list.
    return_utc_times: bool
        If True Return times in UTC timezone and in local timezone otherwise
    connection: :py:class:`pyomd.Connection`
        The connection to be used for discovering nested .otq files
    callback: :py:class:`onetick.py.CallbackBase`
        Class with callback methods.
        If set, the output of the query should be controlled with callbacks
        and this function returns nothing.
    svg_path
        Passed to ``otq.run`` unchanged.
    use_connection_pool
        Passed to ``otq.run`` unchanged.
    node_name: str, List[str], optional
        Name of the output node to select result from. If query graph has several output nodes, you can specify the name
        of the node to choose result from. If node_name was specified, query should be presented by path on the disk
        and output_structure should be `df`
    require_dict: bool
        If set to True, result will be forced to be a dictionary even if it's returned for a single symbol
    max_expected_ticks_per_symbol: int
        Expected maximum number of ticks per symbol (used for performance optimizations).
        By default, ``otp.config.max_expected_ticks_per_symbol`` is used.

    Note
    ----
    It is possible to log currently executed symbol. For that `otp.config.log_symbol` should be set to `True`
    (it can be set via `OTP_LOG_SYMBOL` env var). Note, in this case otp.run does not produce the output so
    it should be used only for debugging purposes.

    Returns
    -------
    result, list, dict, :pandas:`pandas.DataFrame`, None
        result of the query

    Raises
    ------
    ValueError
        If ``date`` is combined with any other time interval parameter,
        or if ``output_structure`` has an unsupported value.

    Examples
    --------

    Running :py:class:`onetick.py.Source` and setting start and end times:

    >>> data = otp.Tick(A=1)
    >>> otp.run(data, start=otp.dt(2003, 12, 2), end=otp.dt(2003, 12, 4))
            Time  A
    0 2003-12-02  1

    Setting query interval with ``date`` parameter:

    >>> data = otp.Tick(A=1)
    >>> data['START'] = data['_START_TIME']
    >>> data['END'] = data['_END_TIME']
    >>> otp.run(data, date=otp.dt(2003, 12, 1))
            Time  A      START        END
    0 2003-12-01  1 2003-12-01 2003-12-02

    Running otq.Ep and passing query parameters:

    >>> ep = otq.TickGenerator(bucket_interval=0, fields='long A = $X').tick_type('TT')
    >>> otp.run(ep, symbols='LOCAL::', query_params={'X': 1})
            Time  A
    0 2003-12-04  1

    Running in callback mode:

    >>> class Callback(otp.CallbackBase):
    ...     def __init__(self):
    ...         self.result = None
    ...     def process_tick(self, tick, time):
    ...         self.result = tick
    >>> data = otp.Tick(A=1)
    >>> callback = Callback()
    >>> otp.run(data, callback=callback)
    >>> callback.result
    {'A': 1}

    Running with ``apply_times_daily``.
    Note that daily intervals are processed separately so, for example,
    we can't access column **COUNT** from previous day.

    >>> trd = otp.DataSource('NYSE_TAQ', symbols='AAPL', tick_type='TRD')  # doctest: +SKIP
    >>> trd = trd.agg({'COUNT': otp.agg.count()},
    ...               bucket_interval=12 * 3600, bucket_time='start')  # doctest: +SKIP
    >>> trd['PREV_COUNT'] = trd['COUNT'][-1]  # doctest: +SKIP
    >>> otp.run(trd, apply_times_daily=True,
    ...         start=otp.dt(2023, 4, 3), end=otp.dt(2023, 4, 5), timezone='EST5EDT')  # doctest: +SKIP
                     Time   COUNT  PREV_COUNT
    0 2023-04-03 00:00:00  328447           0
    1 2023-04-03 12:00:00  240244      328447
    2 2023-04-04 00:00:00  263293           0
    3 2023-04-04 12:00:00  193018      263293
    """
    _ = OneTickLib()

    # resolve "use the configured default" sentinels
    if timezone is utils.default:
        timezone = configuration.config.tz
    if context is utils.default:
        context = configuration.config.context
    if concurrency is utils.default:
        concurrency = configuration.config.default_concurrency
    if batch_size is utils.default:
        batch_size = configuration.config.default_batch_size
    if query_properties is None:
        query_properties = pyomd.QueryProperties()

    # set USE_FT from the configuration only when the caller did not set it explicitly
    str_qp = query_properties.convert_to_name_value_pairs_string().c_str()
    if not any(pair.split('=')[0] == 'USE_FT' for pair in str_qp.split(',')):
        query_properties.set_property_value('USE_FT', otp.config.default_fault_tolerance)

    if date is not None:
        # 'date' is a shortcut for the whole-day interval and can't be mixed with other interval parameters
        for v in (start, end, start_time_expression, end_time_expression):
            if v is not None and v is not utils.adaptive:
                raise ValueError("Can't use 'date' parameter when other time interval parameters are specified")
        start = otp.date(date)
        end = start + otp.Day(1)

    # start/end passed as operations are converted to time expressions
    if isinstance(start, _Operation) and start_time_expression is None:
        start_time_expression = str(start)
        start = utils.adaptive
    if isinstance(end, _Operation) and end_time_expression is None:
        end_time_expression = str(end)
        end = utils.adaptive

    if inspect.ismethod(query) or inspect.isfunction(query):
        # query is a callable: build a symbols source and call it with the corresponding symbol parameter
        t_s = None
        if isinstance(symbols, otp.Source):
            t_s = symbols
        if isinstance(symbols, otp.query):
            t_s = otp.Query(symbols)
        if isinstance(symbols, str):
            t_s = otp.Tick(SYMBOL_NAME=symbols)
        if isinstance(symbols, list):
            t_s = otp.Ticks(SYMBOL_NAME=symbols)
        if isinstance(t_s, otp.Source):
            query = query(t_s.to_symbol_param())  # type: ignore

    query, query_params = _preprocess_otp_query(query, query_params)

    if callback is None and otp.config.log_symbol:
        callback = LogCallback(query)

    output_mode = otq.QueryOutputMode.numpy
    if callback is not None:
        output_mode = otq.QueryOutputMode.callback

    output_structure, output_structure_for_otq = _process_output_structure(output_structure)

    if symbol_date:
        # otq.run supports only strings and datetime.date
        symbol_date = otp.date(symbol_date).to_str()

    require_dict = require_dict or _is_dict_required(symbols)

    # converting symbols properly
    if isinstance(symbols, otp.Source):
        # check if SYMBOL_NAME is in schema, or if schema contains only one field
        symbol_columns = symbols.columns(skip_meta_fields=True)
        if 'SYMBOL_NAME' not in symbol_columns.keys() and len(symbol_columns) != 1:
            warnings.warn('Using as a symbol list a source without "SYMBOL_NAME" field '
                          'and with more than one field! This won\'t work unless the schema is incomplete')
        symbols = symbols._convert_symbol_to_string(
            symbol=symbols,
            tmp_otq=query._tmp_otq if isinstance(query, otp.Source) else None,
            start=start,
            end=end,
            timezone=timezone
        )
    if isinstance(symbols, str):
        symbols = [symbols]
    if isinstance(symbols, pd.DataFrame):
        symbols = utils.get_symbol_list_from_df(symbols)

    if isinstance(query, dict):
        # we assume it's a dictionary of sources for the MultiOutputSource object
        query = otp.MultiOutputSource(query)

    if isinstance(query, (otp.Source, otp.MultiOutputSource)):
        start = None if start is utils.adaptive else start
        end = None if end is utils.adaptive else end
        # If query is an otp.Source object, then it can deal with otp.datetime and pd.Timestamp types,
        # so conversion to pyomd.timeval_t is skipped here.
        # TODO: understand why nanoseconds are not supported here
        start, end = _get_start_end(start, end, timezone, use_pyomd_timeval=False)
        param_upd = query._prepare_for_execution(symbols=symbols, start=start, end=end,
                                                 timezone=timezone,
                                                 start_time_expression=start_time_expression,
                                                 end_time_expression=end_time_expression,
                                                 require_dict=require_dict,
                                                 running_query_flag=running,
                                                 node_name=node_name, has_output=None)
        for key, value in param_upd.items():
            # here we want to make sure we substituted all params from the passed dict,
            # so we raise an error if an unknown parameter is passed in the dict
            if key == 'query': query = value  # noqa: E701
            elif key == 'symbols': symbols = value  # noqa: E701
            elif key == 'start': start = value  # noqa: E701
            elif key == 'end': end = value  # noqa: E701
            elif key == 'start_time_expression': start_time_expression = value  # noqa: E701
            elif key == 'end_time_expression': end_time_expression = value  # noqa: E701
            elif key == 'require_dict': require_dict = value  # noqa: E701
            elif key == 'node_name': node_name = value  # noqa: E701
            elif key == 'time_as_nsec': time_as_nsec = value  # noqa: E701
            else: raise ValueError('Unknown parameter returned!')  # noqa: E701
    elif isinstance(query, (otq.graph_components.EpBase, otq.Chainlet)):
        query = otq.Graph(query)

    # final pass: replaces adaptive values with configured defaults and converts
    # pd.Timestamp / otp.datetime values to pyomd.timeval_t
    start, end = _get_start_end(start, end, timezone)

    # if file name is not in single quotes, then put it in single quotes
    if isinstance(query, str):
        if query[0] != "'" and query[-1] != "'":
            # callback mode doesn't like single quotes
            if output_mode != otq.QueryOutputMode.callback:
                query = f"'{query}'"

    # authentication
    alternative_username = alternative_username or otp.config.default_auth_username
    password = password or otp.config.default_password

    # parameters that are passed to otq.run only if the compatibility checks report them as supported
    kwargs = {}
    if password is not None:
        if has_password_param(throw_warning=True):
            kwargs['password'] = password

    max_expected_ticks_per_symbol = max_expected_ticks_per_symbol or otp.config.max_expected_ticks_per_symbol
    if has_max_expected_ticks_per_symbol(throw_warning=True):
        kwargs['max_expected_ticks_per_symbol'] = max_expected_ticks_per_symbol

    try:
        result = otq.run(query, symbols=symbols, start=start, end=end, context=context, username=username,
                         timezone=timezone, start_time_expression=start_time_expression,
                         end_time_expression=end_time_expression,
                         alternative_username=alternative_username, batch_size=batch_size,
                         running_query_flag=running, query_properties=query_properties,
                         max_concurrency=concurrency, apply_times_daily=apply_times_daily, symbol_date=symbol_date,
                         query_params=query_params, time_as_nsec=time_as_nsec,
                         treat_byte_arrays_as_strings=treat_byte_arrays_as_strings,
                         output_mode=output_mode,
                         output_matrix_per_field=output_matrix_per_field, output_structure=output_structure_for_otq,
                         return_utc_times=return_utc_times, connection=connection,
                         callback=callback, svg_path=svg_path, use_connection_pool=use_connection_pool, **kwargs)
    except Exception as e:
        raise _add_stack_info_to_exception(e)  # noqa: W0707

    if output_mode == otq.QueryOutputMode.callback:
        # in callback mode the output is handled by the callback object, nothing to format
        return result

    # node_names should be either a list of node names or None
    if isinstance(node_name, str):
        node_names = [node_name]
    else:
        node_names = node_name

    return _format_call_output(result, output_structure=output_structure,
                               require_dict=require_dict, node_names=node_names)
def _filter_returned_map_by_node(result, node_names):
    """
    Return the ``map``-structured query result as is.

    ``result`` has the following format: {symbol: {node_name: data}}.
    Filtering by node name is not implemented for this structure yet,
    so ``node_names`` is accepted but currently ignored.
    """
    # TODO: implement filtering by node_name in a way
    # that no information from SymbolNumpyResultMap object is lost
    # (rebuilding the result as plain nested dicts would drop the original result object)
    return result
def _filter_returned_list_by_node(result, node_names):
    """
    Keep only the entries of a ``list``-structured result that belong to the requested nodes.

    ``result`` is a sequence of ``(symbol, data_1, data_2, node_name)`` entries.
    When ``node_names`` is empty or None, ``result`` is returned untouched.
    Otherwise a new list with the entries whose node name is in ``node_names`` is returned.

    An exception is raised when at least one entry carries data (truthy ``data_1``),
    but none of the entries matches the requested node names.
    """
    if not node_names:
        return result

    has_data = False
    selected = []
    for entry in result:
        symbol, data_1, data_2, node = entry
        if data_1:
            has_data = True
        if node in node_names:
            selected.append((symbol, data_1, data_2, node))

    # an empty selection means that no requested node was found
    if has_data and not selected:
        # TODO: Do we even want to raise it?
        raise Exception(f'No passed node name(s) were found in the results. Passed node names were: {node_names}')
    return selected
def _form_dict_from_list(data_list, node_names=None):
    """
    Convert a ``list``-structured result into a dictionary of dataframes keyed by symbol.

    Here, data_list has the following format: [(symbol, data_1, data_2, node_name)],
    where data_1 is an iterable of (column_name, column_values) pairs.
    We need to create the following result:
    either {symbol: pd.DataFrame(data_1)} if there is only one result per symbol
    or {symbol: [pd.DataFrame(data_1)]} if there are multiple results for symbol for a single node_name
    or {symbol: {node_name: pd.DataFrame(data_1)}} if there are single results for multiple node names for a symbol
    or {symbol: {node_name: [pd.DataFrame(data_1)]}} if there are multiple results for multiple node names for a symbol

    ``node_names`` is kept for interface compatibility only: it does not affect the result.
    """

    def form_node_name_dict(lst):
        """
        lst is a list of (node, dataframe)
        """
        by_node = defaultdict(list)
        for node, df in lst:
            by_node[node].append(df)
        # a node with a single dataframe is represented by the dataframe itself, not by a list
        collapsed = {node: dfs[0] if len(dfs) == 1 else dfs for node, dfs in by_node.items()}
        if len(collapsed) == 1:
            # only one node: drop the node level completely
            return next(iter(collapsed.values()))
        return collapsed

    symbols_dict = defaultdict(list)
    for symbol, data, _, node in data_list:
        symbols_dict[symbol].append((node, pd.DataFrame(dict(data))))

    return {symbol: form_node_name_dict(lst) for symbol, lst in symbols_dict.items()}
def _format_call_output(result, output_structure, node_names, require_dict):
    """Formats output of otq.run() according to passed parameters.

    Parameters
    ----------
    result:
        Raw result as returned by otq.run().
    output_structure: ['df', 'list', 'map']
        If 'list' or 'map': the raw result is returned, passed through the node name filter
        of the corresponding structure.
        If 'df': the raw result is expected to be a list; it is filtered by node name and converted
        to a dictionary with symbols as keys and dataframes (possibly grouped by node name) as values.
        If this dictionary has exactly one symbol and require_dict is False,
        then the value for that symbol is returned instead of the dictionary.
    node_names: list of str, None
        If not None, then selects only output returned by nodes in node_names list
    require_dict: bool
        If True, forces output for output_structure='df' to always be a dictionary,
        even if only one symbol is returned.
        Has no effect for other values of output_structure

    Returns
    ----------
        Formatted output: pandas DataFrame, dictionary or list
    """
    if output_structure == 'list':
        return _filter_returned_list_by_node(result, node_names)
    if output_structure == 'map':
        return _filter_returned_map_by_node(result, node_names)

    assert output_structure == 'df', f'Output structure should be one of: "df", "map", "list", ' \
                                     f'instead "{output_structure}" was passed'

    # "df" output structure implies that raw results came as a list
    per_symbol = _form_dict_from_list(_filter_returned_list_by_node(result, node_names), node_names)

    if require_dict or len(per_symbol) != 1:
        return per_symbol
    # single symbol and no dictionary required: unwrap the only value
    return next(iter(per_symbol.values()))
def _preprocess_otp_query(query, query_params):
    """
    Unwrap ``otp.query`` objects into the form accepted by the rest of ``run``.

    An ``otp.query._outputs`` object is replaced with its ``'OUT'`` item first.
    If the query is then an ``otp.query`` object, its path is returned as the query,
    and its parameters (if any) are returned as the query parameters.
    Any other query is returned unchanged together with the passed ``query_params``.

    Raises
    ------
    ValueError
        If parameters are set both in the ``otp.query`` object and in ``query_params``.
    """
    if isinstance(query, otp.query._outputs):
        query = query['OUT']

    if not isinstance(query, otp.query):
        return query, query_params

    own_params = query.params
    if own_params:
        if query_params:
            raise ValueError("please specify parameters in query or in otp.run only")
        query_params = own_params
    return query.path, query_params
def _get_start_end(start, end, timezone, use_pyomd_timeval=True):
    """
    Prepare the start and end times of the query interval.

    For each of ``start`` and ``end``:

    - a plain ``datetime.date`` object is converted to ``datetime.datetime`` at midnight;
    - ``utils.adaptive`` is replaced with the default start / end time from the configuration;
    - ``pd.Timestamp`` and ``otp.datetime`` values are converted to ``pyomd.timeval_t``
      using ``timezone``, unless ``use_pyomd_timeval`` is False;
    - any other value is returned unchanged.

    Returns
    -------
    tuple
        (start, end)
    """
    def _date_to_datetime(value):
        # `isinstance(obj, datetime.date)` is not correct because
        # isinstance(<datetime.datetime object>, datetime.date) = True
        if type(value) is datetime.date:
            return datetime.datetime(value.year, value.month, value.day)
        return value

    def _to_timeval(value):
        if use_pyomd_timeval and isinstance(value, (pd.Timestamp, otp.datetime)):
            return pyomd.timeval_t(pyomd.OT_time_nsec(time2nsectime(value, timezone)))
        return value

    start = _date_to_datetime(start)
    end = _date_to_datetime(end)

    if start is utils.adaptive:
        start = configuration.config.default_start_time
    else:
        start = _to_timeval(start)

    if end is utils.adaptive:
        end = configuration.config.default_end_time
    else:
        end = _to_timeval(end)

    return start, end
def _process_output_structure(output_structure):
    """
    Validate ``output_structure`` and find the matching structure name for otq.

    Returns
    -------
    tuple
        (output_structure, output_structure_for_otq), where the first item is "df"
        when ``output_structure`` is empty or None.

    Raises
    ------
    ValueError
        If ``output_structure`` is not one of: df, list, map.
    """
    if not output_structure or output_structure == "df":
        # otq doesn't support df, so the list structure is requested and converted later
        return "df", "symbol_result_list"
    if output_structure == "list":
        return output_structure, "symbol_result_list"
    if output_structure == "map":
        return output_structure, "symbol_result_map"
    raise ValueError("output_structure support only the following values: df, list and map")
class LogCallback(otp.CallbackBase):
    """
    Callback that prints the query being run and every symbol name reported to it.
    """

    def __init__(self, query_name):
        """Print ``query_name`` and initialize the base callback."""
        print(f'Running query {query_name}')
        super().__init__()

    def process_symbol_name(self, symbol_name):
        """Print the symbol name passed by the caller."""
        print(f'Processing symbol {symbol_name}')