#  Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
import pandas as pd

# Separator placed between the dtype strings returned by `get_column_types`.
TABLE_TYPE_NEXT_VALUE_SEPARATOR = '__pydev_table_column_type_val__'
# Integer `display.max_colwidth` used on Python 2 or pandas < 1 instead of None.
MAX_COLWIDTH_PYTHON_2 = 100000
# Upper bound on the batch size requested when a dataset is converted to pandas.
BATCH_SIZE = 10000


def get_type(table):
    # type: (object) -> str
    """Return the string form of the runtime type of `table`."""
    table_type = type(table)
    return str(table_type)


# noinspection PyUnresolvedReferences
def get_shape(table):
    # type: (datasets.arrow_dataset.Dataset) -> str
    """Return the first entry of `table.shape` converted to a string."""
    shape = table.shape
    return str(shape[0])


# noinspection PyUnresolvedReferences
def get_head(table):
    # type: (datasets.arrow_dataset.Dataset) -> str
    """Return the `repr` of an HTML rendering of the first row of `table`.

    Only the row at index 0 is selected before the conversion to a DataFrame.
    """
    first_row = table.select([0])
    preview = __convert_to_df(first_row).head()
    html = preview.to_html(notebook=True, max_cols=None)
    return repr(html)


# noinspection PyUnresolvedReferences
def get_column_types(table):
    # type: (datasets.arrow_dataset.Dataset) -> str
    """Return the index dtype and the column dtypes of `table` as one string.

    The dtypes are read from a DataFrame built from the row at index 0. The
    index dtype comes first and every value is separated from the next one by
    TABLE_TYPE_NEXT_VALUE_SEPARATOR.
    """
    sample = __convert_to_df(table.select([0]))
    index_type = str(sample.index.dtype)
    column_types = [str(dtype) for dtype in sample.dtypes]
    return index_type + TABLE_TYPE_NEXT_VALUE_SEPARATOR + \
        TABLE_TYPE_NEXT_VALUE_SEPARATOR.join(column_types)


# used by pydevd
# noinspection PyUnresolvedReferences
def get_data(table, start_index=None, end_index=None):
    # type: (datasets.arrow_dataset.Dataset, int, int) -> str
    """Return the `repr` of the HTML rendering of `table`.

    The rendering is delegated to `_compute_sliced_data`, which restricts the
    table to rows `start_index` to `end_index` when both are given.
    """

    def render_html(frame, column_limit):
        html = frame.to_html(notebook=True, max_cols=column_limit)
        return repr(html)

    return _compute_sliced_data(table, render_html, start_index, end_index)


# used by DSTableCommands
# noinspection PyUnresolvedReferences
def display_data(table, start_index, end_index):
    # type: (datasets.arrow_dataset.Dataset, int, int) -> None
    """Show `table` through `IPython.display.display`.

    The slicing by `start_index` and `end_index` is delegated to
    `_compute_sliced_data`.
    """
    def show(frame, _column_limit):
        # IPython is imported only at the moment something is displayed.
        from IPython.display import display
        display(frame)

    _compute_sliced_data(table, show, start_index, end_index)


def __get_data_slice(table, start, end):
    # type: (datasets.arrow_dataset.Dataset, int, int) -> pd.DataFrame
    """Convert `table` to a DataFrame and return the positional slice `start:end`."""
    frame = __convert_to_df(table)
    return frame.iloc[start:end]


def _compute_sliced_data(table, fun, start_index=None, end_index=None):
    # type: (datasets.arrow_dataset.Dataset, function, int, int) -> str
    """Apply `fun` to `table`, or to a slice of it, under temporary pandas display options.

    The `display.max_columns`, `display.max_rows` and `display.max_colwidth`
    options are overridden with the values from `__get_tables_display_options`
    while `fun` runs. The previous values are restored afterwards, including
    when the conversion, the slicing or `fun` raises.

    :param table: the table to convert to a DataFrame.
    :param fun: callable invoked as `fun(data_frame, max_cols)`.
    :param start_index: first row of the slice; used only together with `end_index`.
    :param end_index: end of the slice (exclusive); used only together with `start_index`.
    :return: whatever `fun` returns.
    """
    max_cols, max_colwidth, max_rows = __get_tables_display_options()

    # Remember the current settings so they can be put back afterwards.
    _jb_max_cols = pd.get_option('display.max_columns')
    _jb_max_colwidth = pd.get_option('display.max_colwidth')
    _jb_max_rows = pd.get_option('display.max_rows')

    try:
        pd.set_option('display.max_columns', max_cols)
        pd.set_option('display.max_rows', max_rows)
        pd.set_option('display.max_colwidth', max_colwidth)

        if start_index is not None and end_index is not None:
            table = __get_data_slice(table, start_index, end_index)
        else:
            table = __convert_to_df(table)

        return fun(table, max_cols)
    finally:
        # The options are process-wide: restore them even on failure so an
        # exception does not leave the overrides active.
        pd.set_option('display.max_columns', _jb_max_cols)
        pd.set_option('display.max_colwidth', _jb_max_colwidth)
        pd.set_option('display.max_rows', _jb_max_rows)


# In old versions of pandas, max_colwidth accepted only integer values
def __get_tables_display_options():
    # type: () -> Tuple[None, Union[int, None], None]
    """Return the `(max_cols, max_colwidth, max_rows)` display options to apply.

    `max_cols` and `max_rows` are always None. `max_colwidth` is
    MAX_COLWIDTH_PYTHON_2 on Python 2 or when the major version of pandas is
    below 1, and None otherwise.
    """
    import sys

    needs_int_colwidth = sys.version_info < (3, 0)
    if not needs_int_colwidth:
        try:
            import pandas as pd
            major_version = int(pd.__version__.split('.')[0])
            needs_int_colwidth = major_version < 1
        except ImportError:
            # Without pandas there is no version to check.
            pass

    if needs_int_colwidth:
        return None, MAX_COLWIDTH_PYTHON_2, None
    return None, None, None


# noinspection PyUnresolvedReferences
def __convert_to_df(table):
    # type: (datasets.arrow_dataset.Dataset) -> pd.DataFrame
    """Return `table` as a DataFrame when it is exactly a `datasets` Dataset.

    Any other object is returned unchanged, and so is `table` when an
    ImportError is raised (for example because `datasets` is not installed).
    """
    converted = table
    try:
        import datasets
        is_dataset = type(table) is datasets.arrow_dataset.Dataset
        if is_dataset:
            converted = __dataset_to_df(table)
    except ImportError:
        # Fall through and hand the input back as it is.
        pass
    return converted


def __dataset_to_df(dataset):
    # type: (datasets.arrow_dataset.Dataset) -> pd.DataFrame
    """Convert `dataset` to a single pandas DataFrame, reading it in batches.

    The data is requested from `dataset.to_pandas` in batches of at most
    BATCH_SIZE rows. Several batches are concatenated under a fresh index
    (`ignore_index=True`); a single batch is returned as it is.

    A dataset without rows is converted with a plain, non-batched
    `to_pandas()` call: a batch size of 0 would be requested otherwise, and
    with no batches there is no first batch to return.

    :param dataset: the dataset to convert.
    :return: the resulting DataFrame, or None when the conversion raises
        ImportError.
    """
    try:
        row_count = len(dataset)
        if row_count == 0:
            # Nothing to batch; indexing the first batch would fail.
            return dataset.to_pandas()

        batches = list(dataset.to_pandas(batched=True, batch_size=min(row_count, BATCH_SIZE)))
        if len(batches) > 1:
            return pd.concat(batches, ignore_index=True)
        else:
            return batches[0]
    except ImportError:
        # Best effort: a missing optional dependency results in None rather
        # than an exception.
        return None