# coding: utf-8

# Copyright 2014-2025 Álvaro Justen <https://github.com/turicas/rows/>
#    This program is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
#    Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
#    any later version.
#    This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied
#    warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for
#    more details.
#    You should have received a copy of the GNU Lesser General Public License along with this program.  If not, see
#    <http://www.gnu.org/licenses/>.

from __future__ import unicode_literals

import itertools
import random
import types
import unittest
from collections import OrderedDict

import mock

import rows
import rows.plugins.utils as plugins_utils
import tests.utils as utils
from rows import fields
from rows.compat import TEXT_TYPE


class GenericUtilsTestCase(unittest.TestCase):
    """Tests for generic helpers in `rows.plugins.utils`."""

    def test_ipartition(self):
        """`ipartition` returns a generator of lists of at most `n` items.

        The last partition holds the leftover items, so it may be shorter.
        """
        iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

        result = plugins_utils.ipartition(iterable, 3)
        assert isinstance(result, types.GeneratorType)
        assert list(result) == [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]

        result = plugins_utils.ipartition(iterable, 2)
        assert isinstance(result, types.GeneratorType)
        assert list(result) == [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]


def possible_field_names_errors(error_fields):
    """Return every "Invalid field names" message that could list `error_fields`.

    Each name is wrapped in double quotes and the names are joined with ", ".
    One message is produced per ordering of the names, so a test can check
    membership (`in`) without depending on the order the code reports them.
    """
    quoted_names = ['"{}"'.format(name) for name in error_fields]
    return [
        "Invalid field names: {}".format(", ".join(ordering))
        for ordering in itertools.permutations(quoted_names)
    ]


class PluginUtilsTestCase(utils.RowsTestMixIn, unittest.TestCase):
    """Tests for `create_table`, `prepare_to_export`, `serialize` and
    `fields.make_unique_name`."""

    def test_create_table_skip_header(self):
        """With explicit `fields`, `skip_header=True` drops the first data row."""
        field_types = OrderedDict(
            [("integer", fields.IntegerField), ("string", fields.TextField)]
        )
        data = [["1", "Álvaro"], ["2", "turicas"], ["3", "Justen"]]
        table_1 = plugins_utils.create_table(data, fields=field_types, skip_header=True)
        table_2 = plugins_utils.create_table(
            data, fields=field_types, skip_header=False
        )

        assert field_types == table_1.fields
        assert table_1.fields == table_2.fields
        assert len(table_1) == 2
        assert len(table_2) == 3

        first_row = {"integer": 1, "string": "Álvaro"}
        second_row = {"integer": 2, "string": "turicas"}
        third_row = {"integer": 3, "string": "Justen"}
        assert dict(table_1[0]._asdict()) == second_row
        assert dict(table_2[0]._asdict()) == first_row
        assert dict(table_1[1]._asdict()) == third_row
        assert dict(table_2[1]._asdict()) == second_row
        assert dict(table_2[2]._asdict()) == third_row

    def test_create_table_import_fields(self):
        """`import_fields` selects which columns are kept, in the given order."""
        header = ["field1", "field2", "field3"]
        table_rows = [
            ["1", 3.14, "Álvaro"],
            ["2", 2.71, "turicas"],
            ["3", 1.23, "Justen"],
        ]
        table = plugins_utils.create_table([header] + table_rows, import_fields=None)
        assert list(table.fields.keys()) == header
        assert table[0].field1 == 1
        assert table[0].field2 == 3.14
        assert table[0].field3 == "Álvaro"

        import_fields = ["field3", "field2"]
        table = plugins_utils.create_table(
            [header] + table_rows, import_fields=import_fields
        )
        assert list(table.fields.keys()) == import_fields
        assert table[0]._asdict() == OrderedDict([("field3", "Álvaro"), ("field2", 3.14)])

    def test_create_table_import_fields_ordering(self):
        """Values stay matched to their columns when `import_fields` reorders them."""
        # From: https://github.com/turicas/rows/issues/239

        data = [
            ["intfield", "textfield", "floatfield"],
            [1, "str1", 1.2],
            [2, "str2", 2.3],
            [3, "str3", 3.4],
        ]
        # The `fields` parameter of `create_table` must always be in the same
        # order as the data. (Named `field_types` here so it does not shadow
        # the module-level `fields` import.)
        field_types = OrderedDict(
            [
                ("intfield", rows.fields.IntegerField),
                ("textfield", rows.fields.TextField),
                ("floatfield", rows.fields.FloatField),
            ]
        )

        # Regular case: no `import_fields` specified
        table = plugins_utils.create_table(data, fields=field_types, skip_header=True)
        assert table.fields == field_types
        for row, row_data in zip(table, data[1:]):
            assert row_data == [row.intfield, row.textfield, row.floatfield]

        # Special case: `import_fields` has a different order from `fields`
        import_fields = ["textfield", "intfield"]
        table = plugins_utils.create_table(
            data, fields=field_types, import_fields=import_fields, skip_header=True
        )
        assert list(table.fields.keys()) == import_fields
        for row, row_data in zip(table, data[1:]):
            assert row_data[1] == row.textfield
            assert row_data[0] == row.intfield

    def test_create_table_import_fields_dont_exist(self):
        """Unknown names in `import_fields` raise `ValueError` listing them."""
        header = ["field1", "field2", "field3"]
        table_rows = [
            ["1", 3.14, "Álvaro"],
            ["2", 2.71, "turicas"],
            ["3", 1.23, "Justen"],
        ]

        error_fields = ["doesnt_exist", "ruby"]
        import_fields = header[:-1] + error_fields
        with self.assertRaises(ValueError) as exception_context:
            plugins_utils.create_table(
                [header] + table_rows, import_fields=import_fields
            )

        # The message may list the invalid names in any order.
        assert exception_context.exception.args[0] in possible_field_names_errors(error_fields)

    def test_create_table_repeated_field_names(self):
        """Repeated or empty header names are turned into unique field names."""
        header = ["first", "first", "first"]
        table_rows = [
            ["1", 3.14, "Álvaro"],
            ["2", 2.71, "turicas"],
            ["3", 1.23, "Justen"],
        ]
        table = plugins_utils.create_table([header] + table_rows)
        assert list(table.fields.keys()) == ["first", "first_2", "first_3"]
        assert table[0].first == 1
        assert table[0].first_2 == 3.14
        assert table[0].first_3 == "Álvaro"

        header = ["field", "", "field"]
        table_rows = [
            ["1", 3.14, "Álvaro"],
            ["2", 2.71, "turicas"],
            ["3", 1.23, "Justen"],
        ]
        table = plugins_utils.create_table([header] + table_rows)
        assert list(table.fields.keys()) == ["field", "field_1", "field_2"]
        assert table[0].field == 1
        assert table[0].field_1 == 3.14
        assert table[0].field_2 == "Álvaro"

    def test_create_table_empty_data(self):
        """A header with no data rows still produces (deduplicated) fields."""
        header = ["first", "first", "first"]
        table_rows = []
        table = plugins_utils.create_table([header] + table_rows)
        assert list(table.fields.keys()) == ["first", "first_2", "first_3"]
        assert len(table) == 0

    def test_create_table_force_types(self):
        """`force_types` overrides the detected type of the given fields."""
        header = ["field1", "field2", "field3"]
        table_rows = [
            ["1", "3.14", "Álvaro"],
            ["2", "2.71", "turicas"],
            ["3", "1.23", "Justen"],
        ]
        force_types = {"field2": rows.fields.DecimalField}

        table = plugins_utils.create_table(
            [header] + table_rows, force_types=force_types
        )
        for field_name, field_type in force_types.items():
            assert table.fields[field_name] == field_type

    def test_create_table_different_number_of_fields(self):
        """Columns beyond the header get generated names instead of being dropped."""
        header = ["field1", "field2"]
        table_rows = [
            ["1", "3.14", "Álvaro"],
            ["2", "2.71", "turicas"],
            ["3", "1.23", "Justen"],
        ]
        table = plugins_utils.create_table([header] + table_rows)
        assert list(table.fields.keys()) == ["field1", "field2", "field_2"]
        assert table[0].field1 == 1
        assert table[0].field2 == 3.14
        assert table[0].field_2 == "Álvaro"
        assert table[1].field1 == 2
        assert table[1].field2 == 2.71
        assert table[1].field_2 == "turicas"
        assert table[2].field1 == 3
        assert table[2].field2 == 1.23
        assert table[2].field_2 == "Justen"

    def test_create_table_optimization_is_the_same_as_extending(self):
        """`create_table` stores the same rows as `Table.extend` would."""
        # TODO: do the same for FlexibleTable when `create_table` accepts it (or other kinds of table classes)
        header = ["f1", "f2", "f3"]
        table_rows = [
            ["1", "3.14", "Álvaro"],
            ["2", "2.71", "turicas"],
            ["3", "1.23", "Justen"],
        ]
        table_1 = plugins_utils.create_table([header] + table_rows)
        table_2 = rows.Table(fields=table_1.fields.copy())
        table_2.extend(dict(zip(header, row)) for row in table_rows)
        assert table_1._rows == table_2._rows

    def test_create_table_optimization_is_the_same_as_extending_custom_fields(self):
        """Same as above, but with `import_fields` selecting/reordering columns."""
        # TODO: do the same for FlexibleTable when `create_table` accepts it (or other kinds of table classes)
        header = ["f1", "f2", "f3"]
        table_rows = [
            ["1", "3.14", "Álvaro"],
            ["2", "2.71", "turicas"],
            ["3", "1.23", "Justen"],
        ]
        table_1 = plugins_utils.create_table([header] + table_rows, import_fields=["f3", "f2"])
        table_2 = rows.Table(fields=table_1.fields.copy())
        table_2.extend({"f3": row[2], "f2": row[1]} for row in table_rows)
        assert table_1._rows == table_2._rows

    def test_prepare_to_export_all_fields(self):
        """Without `export_fields`, the header and then every raw row are yielded."""
        result = plugins_utils.prepare_to_export(utils.table, export_fields=None)

        assert tuple(utils.table.fields.keys()) == next(result)

        for row in utils.table._rows:
            assert row == next(result)

        with self.assertRaises(StopIteration):
            next(result)

    def test_prepare_to_export_some_fields(self):
        """A (shuffled) subset of fields is exported in the requested order."""
        field_names = list(utils.table.fields.keys())
        number_of_fields = random.randint(2, len(field_names) - 1)
        some_fields = [field_names[index] for index in range(number_of_fields)]
        random.shuffle(some_fields)
        result = plugins_utils.prepare_to_export(utils.table, export_fields=some_fields)

        assert tuple(some_fields) == next(result)

        for row in utils.table:
            expected_row = tuple(getattr(row, field_name) for field_name in some_fields)
            assert expected_row == next(result)

        with self.assertRaises(StopIteration):
            next(result)

    def test_prepare_to_export_some_fields_dont_exist(self):
        """Unknown `export_fields` raise `ValueError` once the generator is consumed."""
        field_names = list(utils.table.fields.keys())
        error_fields = ["does_not_exist", "java"]
        export_fields = field_names + error_fields
        # Creating the generator is not expected to raise; the first `next()` is.
        result = plugins_utils.prepare_to_export(
            utils.table, export_fields=export_fields
        )
        with self.assertRaises(ValueError) as exception_context:
            next(result)

        assert exception_context.exception.args[0] in possible_field_names_errors(error_fields)

    def test_prepare_to_export_with_FlexibleTable(self):
        """`FlexibleTable` rows are exported as tuples ordered by its fields."""
        flexible = rows.FlexibleTable()
        for row in utils.table:
            flexible.append(row._asdict())

        field_names = tuple(flexible.fields.keys())
        prepared = plugins_utils.prepare_to_export(flexible)
        assert next(prepared) == field_names

        # Materialize the rows first: a bare `zip` stops at the shorter input,
        # so missing output rows would otherwise go unnoticed.
        prepared_rows = list(prepared)
        assert len(prepared_rows) == len(flexible._rows)
        for row, expected_row in zip(prepared_rows, flexible._rows):
            values = tuple(expected_row[field_name] for field_name in field_names)
            assert values == row

    def test_prepare_to_export_with_FlexibleTable_and_export_fields(self):
        """`FlexibleTable` export honours a subset of `export_fields`."""
        flexible = rows.FlexibleTable()
        for row in utils.table:
            # Conversion to the text type is needed on Python 2, since
            # namedtuples' keys are bytes there, not unicode.
            flexible.append(
                {TEXT_TYPE(key): value for key, value in row._asdict().items()}
            )

        field_names = list(flexible.fields.keys())
        export_fields = tuple(field_names[: len(field_names) // 2])
        prepared = plugins_utils.prepare_to_export(
            flexible, export_fields=export_fields
        )
        assert next(prepared) == export_fields

        # Materialize the rows first so a truncated export is detected.
        prepared_rows = list(prepared)
        assert len(prepared_rows) == len(flexible._rows)
        for row, expected_row in zip(prepared_rows, flexible._rows):
            values = tuple(expected_row[field_name] for field_name in export_fields)
            assert values == row

    def test_prepare_to_export_wrong_obj_type(self):
        """`prepare_to_export` raises exception if obj isn't `*Table`"""

        with self.assertRaises(ValueError) as exception_context:
            next(plugins_utils.prepare_to_export(1))
        assert exception_context.exception.args[0] == "Table type 'int' not recognized"

        with self.assertRaises(ValueError) as exception_context:
            next(plugins_utils.prepare_to_export(42.0))
        assert exception_context.exception.args[0] == "Table type 'float' not recognized"

        with self.assertRaises(ValueError) as exception_context:
            next(plugins_utils.prepare_to_export([list("abc"), [1, 2, 3]]))
        assert exception_context.exception.args[0] == "Table type 'list' not recognized"

    @mock.patch("rows.plugins.utils.prepare_to_export", return_value=iter([[], [], []]))
    def test_serialize_should_call_prepare_to_export(self, mocked_prepare_to_export):
        """`serialize` lazily delegates to `prepare_to_export`, forwarding kwargs."""
        table = utils.table
        kwargs = {"export_fields": 123, "other_parameter": 3.14}
        result = plugins_utils.serialize(table, **kwargs)
        # Nothing is called until the generator is consumed.
        assert not mocked_prepare_to_export.called
        list(result)  # exhaust the generator
        assert mocked_prepare_to_export.called
        assert mocked_prepare_to_export.call_count == 1
        assert mock.call(table, **kwargs) == mocked_prepare_to_export.call_args

    def test_serialize(self):
        """Each value is serialized with the field type of its column."""
        result = plugins_utils.serialize(utils.table)
        field_types = list(utils.table.fields.values())
        assert next(result) == tuple(utils.table.fields.keys())

        # Materialize the rows first so a truncated output is detected.
        serialized_rows = list(result)
        assert len(serialized_rows) == len(utils.table._rows)
        for row, expected_row in zip(serialized_rows, utils.table._rows):
            values = [
                field_type.serialize(value)
                for field_type, value in zip(field_types, expected_row)
            ]
            assert values == row

    def test_make_unique_name(self):
        """`make_unique_name` picks the first free index, honouring `start`/`max_size`."""
        name = "test"
        existing_names = []
        name_format = "{index}_{name}"

        result = fields.make_unique_name(name, existing_names, name_format)
        assert result == name

        existing_names = ["test"]
        result = fields.make_unique_name(name, existing_names, name_format)
        assert result == "2_test"

        existing_names = ["test", "2_test", "3_test", "5_test"]
        result = fields.make_unique_name(name, existing_names, name_format)
        assert result == "4_test"

        existing_names = ["test", "2_test", "3_test", "5_test"]
        result = fields.make_unique_name(name, existing_names, name_format, start=1)
        assert result == "1_test"

        existing_names = ["test"]
        result = fields.make_unique_name(name, existing_names, start=1, max_size=4)
        assert result == "te_1"

    # TODO: test all features of create_table
    # TODO: test if error is raised if len(row) != len(fields)
