"""
@desc 数据库操作方法封装
@auth  chenkai
@date 2020/11/19
@py_version py3.6
"""
import pymysql
import logging as log
import pandas as pd
import time
from model.log import logger
# Project logger from model.log; this rebinding of ``log`` shadows the
# ``import logging as log`` made above, so every ``log.*`` call below uses it.
log = logger()

# Print DataFrames with every column and on wide lines (useful in logs).
pd.options.display.max_columns = None
pd.options.display.width = 1000

# Truthy: getData / getData_pd log the rendered SQL and how long it took.
MYSQL_DEBUG = 1


class MysqlOperation:
    """Convenience wrapper around one pymysql connection and its cursor.

    All queries run on ``self.cursor``. Public write helpers commit on
    ``self.conn`` before returning; ``executeWithoutCommit`` and the private
    ``_insertorupdate`` / ``_insert_on_duplicate`` helpers leave the commit to
    the caller. Usable as a context manager: leaving the ``with`` block closes
    the cursor and connection without committing.
    """

    def __init__(self, host, user, passwd, db, port=3306):
        """Connect to MySQL (charset utf8mb4) and open a default cursor.

        :param host: server host name or address
        :param user: login user
        :param passwd: login password
        :param db: default database of the connection
        :param port: server port
        :raises Exception: whatever ``pymysql.connect`` raises. The error is
            logged and re-raised; swallowing it would leave an instance with
            no ``conn``/``cursor`` that only fails later with AttributeError.
        """
        try:
            self.conn = pymysql.connect(host=host,
                                        user=user,
                                        passwd=passwd,
                                        db=db,
                                        charset='utf8mb4',
                                        port=port)
        except Exception as e:
            log.error(e)
            raise
        self.cursor = self.conn.cursor()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Uncommitted work (e.g. from executeWithoutCommit) is not committed here.
        self.close()
        return False

    def close(self):
        """Close the cursor and, if it is still open, the connection."""
        self.cursor.close()
        if self.conn.open:
            self.conn.close()

    def set_dict_cursor(self):
        """Replace ``self.cursor`` with a DictCursor so fetched rows are dicts.

        Helpers that index rows by position (``getOne``, ``getDataOneList``,
        ``get_data_list``, ``getColumn``, ``_get_exist_keys_index``) assume
        tuple rows and do not work as intended after this call.
        """
        self.cursor = self.conn.cursor(pymysql.cursors.DictCursor)

    def _sql_for_log(self, sql, args=None):
        """Render ``sql`` with ``args`` for log output only.

        ``cursor.mogrify`` renders sequence, dict (``%(name)s``) and scalar
        args with the same quoting used for the real query. Logging must
        never break the statement it describes, so any rendering failure falls
        back to the raw SQL followed by the repr of the args.
        """
        if not args:
            return sql
        try:
            return self.cursor.mogrify(sql, args)
        except Exception:
            return '%s -- args: %r' % (sql, args)

    def getData(self, sql, args=None):
        """Execute a query and return all rows.

        :param sql: SQL with pymysql placeholders
        :param args: placeholder values (sequence or dict) or None
        :return: the cursor's ``fetchall()`` result, i.e. tuple(tuple) for the
            default cursor
        """
        start = time.time()
        self.cursor.execute(sql, args=args)
        result = self.cursor.fetchall()
        if MYSQL_DEBUG:
            log.info('sql: \n' + self._sql_for_log(sql, args))
            log.info('sql cost: %s' % (time.time() - start))
        return result

    def get_data_list(self, sql, arg=None):
        """Execute a query and return the rows as list[list]."""
        return [list(row) for row in self.getData(sql, arg)]

    def getDataOneList(self, sql, args=None):
        """Execute a query and return its first column as a flat list."""
        return [row[0] for row in self.getData(sql, args)]

    def execute(self, sql, data=None):
        """Execute one statement and commit.

        :param sql: SQL statement
        :param data: placeholder values; when falsy the statement is executed
            without parameters
        :return: the row count returned by ``cursor.execute`` (also printed)
        """
        if data:
            affected = self.cursor.execute(sql, data)
        else:
            affected = self.cursor.execute(sql)
        self.conn.commit()
        print(f"affect rows :{affected}")
        return affected

    def executeMany(self, sql, data):
        """Run ``cursor.executemany(sql, data)`` and commit.

        :return: the row count returned by ``executemany`` (also printed)
        """
        affected = self.cursor.executemany(sql, data)
        self.conn.commit()
        print(f"\033[1;36maffect rows :{affected} \033[0m")
        return affected

    def getOne(self, sql, args=None):
        """Return the first column of the first row.

        :raises IndexError: when the query returns no rows
        """
        return self.getData(sql, args)[0][0]

    def getData_pd(self, sql, args=None):
        """Execute a query and return the result as a DataFrame.

        Column names are taken from ``cursor.description``.
        """
        start = time.time()
        self.cursor.execute(sql, args=args)
        field_names = [column[0] for column in self.cursor.description]
        rows = self.cursor.fetchall()
        df = pd.DataFrame(data=list(rows), columns=field_names)
        if MYSQL_DEBUG:
            log.info('sql: \n' + self._sql_for_log(sql, args))
            log.info('sql cost: %s' % (time.time() - start))
        return df

    def executeWithoutCommit(self, sql, args=None):
        """Execute a statement without committing; return the cursor's row count."""
        return self.cursor.execute(sql, args=args)

    def commit(self):
        """Commit the current transaction."""
        self.conn.commit()

    @staticmethod
    def _build_update_sql(table, keys, tags):
        """Return ``UPDATE table SET t1=%s,... WHERE k1=%s AND ...``.

        Placeholder order: the ``tags`` values first, then the ``keys`` values.
        """
        set_clause = ",".join(tag + "=%s" for tag in tags)
        where_clause = " AND ".join(key + "=%s" for key in keys)
        return "UPDATE " + table + " SET " + set_clause + " WHERE " + where_clause

    @staticmethod
    def _build_insert_if_absent_sql(table, keys, tags):
        """Return an INSERT ... SELECT that inserts only if no row has these keys.

        Placeholder order: key values, tag values, then the key values again
        for the NOT EXISTS sub-query.
        """
        columns = ",".join(list(keys) + list(tags))
        placeholders = ",".join(["%s"] * (len(keys) + len(tags)))
        where_clause = " AND ".join(key + "=%s" for key in keys)
        return ("INSERT INTO " + table + " (" + columns + ") SELECT " + placeholders
                + " FROM DUAL WHERE NOT EXISTS (SELECT 1 FROM " + table
                + " WHERE " + where_clause + " LIMIT 1)")

    def insertorupdate(self, table, keys, tags, tagvalue, flag, *args):
        """Insert one row, or update it when its key already exists; then commit.

        :param table: table name
        :param keys: names of the (composite) key columns
        :param tags: names of the value columns
        :param tagvalue: values for ``tags``, same order
        :param flag: log the last executed SQL when truthy
        :param args: values for ``keys``, same order
        """
        self._insertorupdate(table, keys, tags, tagvalue, flag, args)
        self.conn.commit()

    def _insertorupdate(self, table, keys, tags, tag_value, flag, key_value, update=False):
        """Upsert one row without committing.

        With ``update=False`` a NOT EXISTS-guarded INSERT runs first; when it
        affects no rows (the key exists) an UPDATE follows. With ``update=True``
        only the UPDATE runs. When ``flag`` is truthy the last statement is logged.
        """
        key_value = list(key_value)
        rows = 0
        if not update:
            sql = self._build_insert_if_absent_sql(table, keys, tags)
            arg = key_value + list(tag_value) + key_value
            rows = self.cursor.execute(sql, args=arg)
        if rows == 0:
            sql = self._build_update_sql(table, keys, tags)
            arg = list(tag_value) + key_value
            self.cursor.execute(sql, args=arg)
        if flag:
            log.info(self._sql_for_log(sql, arg))

    def _insert_on_duplicate(self, table, keys, tags, tag_value, flag, key_value):
        """Upsert one row via INSERT ... ON DUPLICATE KEY UPDATE, without committing.

        Intended for tables where ``keys`` is a unique key. The tag values are
        bound twice: once in VALUES and once in the UPDATE clause.
        """
        columns = list(keys) + list(tags)
        arg = list(key_value) + list(tag_value) * 2
        sql = """
            insert into %s
            (%s)
            VALUES (%s)
            ON duplicate key UPDATE %s
        """ % (table,
               ','.join(columns),
               ','.join(['%s'] * len(columns)),
               ','.join(tag + '=%s' for tag in tags))
        self.cursor.execute(sql, args=arg)
        if flag:
            log.debug(self._sql_for_log(sql, arg))

    def insertorupdatemany(self, table, keys, tags, tag_values, key_values, flag=False, unique_key=False, update=False):
        """Upsert many rows, one statement per row, and commit once at the end.

        :param table: table name
        :param keys: names of the (composite) key columns
        :param tags: names of the value columns
        :param tag_values: rows of tag values (list or pd.DataFrame)
        :param key_values: rows of key values (list or pd.DataFrame)
        :param flag: log each statement when truthy
        :param unique_key: ``keys`` is a unique key of ``table``; use
            INSERT ... ON DUPLICATE KEY UPDATE
        :param update: when ``unique_key`` is False, only UPDATE (never insert)
        ps: efficiency (external network): rows / 50; use for 1000+ rows.
        DataFrame NaN/NaT values are sent as NULL.
        """
        list_tag_value = self._convert_to_list(tag_values)
        list_key_value = self._convert_to_list(key_values)
        for i, tag_value in enumerate(list_tag_value):
            key_value = list_key_value[i]
            if unique_key:
                self._insert_on_duplicate(table, keys, tags, tag_value, flag, key_value)
            else:
                self._insertorupdate(table, keys, tags, tag_value, flag, key_value, update)
        self.conn.commit()

    def _check_repeat_key(self, key_list):
        """Return True, logging up to 10 offenders, if any key row repeats."""
        seen = set()
        repeat_key = set()
        for key in map(tuple, key_list):
            if key in seen:
                repeat_key.add(key)
                if len(repeat_key) >= 10:
                    break
            else:
                seen.add(key)
        if not repeat_key:
            return False
        log.error('Reject repeated keys')
        log.error('repeat_key: %s' % repeat_key)
        return True

    def _convert_to_list(self, data):
        """Return ``data`` as a list of rows.

        A DataFrame becomes list[list] with missing values (NaN/NaT compare
        unequal to themselves) replaced by None. Rows come from ``itertuples``
        rather than ``iloc[i, :]``: a row Series is upcast to one dtype, which
        turned ints into floats (losing precision for big ids) whenever a frame
        mixed int and float columns. Anything else is returned as ``list(data)``
        with its rows untouched.
        """
        if not isinstance(data, pd.DataFrame):
            return list(data)
        return [[None if value != value else value for value in row]
                for row in data.itertuples(index=False, name=None)]

    def _get_exist_keys_index(self, table, keys, key_values, flag=False):
        """Return the positions in ``key_values`` whose key exists in ``table``.

        A single SELECT maps each matching table row back to its input position
        with ``case when (k1,k2)=(%s,%s) then i``; the WHERE clause ORs one
        ``(k1 = %s and k2 = %s)`` condition per input row. Placeholders are
        bound for all CASE branches first, then for all WHERE conditions.
        A position is returned once per matching table row, so it can repeat
        when the table holds duplicate keys.
        """
        key_columns = ','.join(keys)
        row_condition = '(%s)' % ' and '.join('%s = %%s' % key for key in keys)
        list_sql_when = []
        flat_keys = []
        for i, key_value in enumerate(key_values):
            list_sql_when.append('when (%s)=(%s) then %s'
                                 % (key_columns, ','.join(['%s'] * len(key_value)), i))
            flat_keys.extend(key_value)
        sql_where = ' or '.join([row_condition] * len(key_values))
        sql = """
            select
            case
                %s
            end
            from %s
            where %s
        """ % ('\n'.join(list_sql_when), table, sql_where)
        params = tuple(flat_keys) * 2
        if flag:
            log.info(self._sql_for_log(sql, params))
        self.cursor.execute(sql, params)
        return [row[0] for row in self.cursor.fetchall()]

    def insertorupdatemany_v2(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
        """Insert or update many rows: absent keys are inserted, existing keys updated.

        :param table: table name
        :param keys: names of the (composite) key columns
        :param tags: names of the value columns
        :param tag_values: rows of tag values (list, tuple or pd.DataFrame)
        :param key_values: rows of key values (list or pd.DataFrame)
        :param flag: log the generated SQL when truthy
        :param split: batch size; larger inputs are processed in chunks of this size
        :raises TypeError: if ``tag_values`` is not a tuple, list or DataFrame
        :raises AssertionError: if a batch contains the same key twice
        ps: efficiency (external network): rows^2 / 50000; rows are updated in
        batches of ``split``.
        """
        if not isinstance(tag_values, (tuple, list, pd.DataFrame)):
            log.error('Type Error')
            raise TypeError('tag_values must be a tuple, list or pandas.DataFrame, got %s'
                            % type(tag_values).__name__)
        if len(tag_values) > split:
            for start in range(0, len(tag_values), split):
                finish = start + split
                self.insertorupdatemany_v2(table, keys, tags, tag_values[start:finish],
                                           key_values[start:finish], flag, split=split)
            return
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('insert or update 0 rows')
            return
        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)
        if self._check_repeat_key(key_values):
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; AssertionError is kept for existing callers.
            raise AssertionError('Reject repeated keys')

        # Dedupe: a position repeats when the table itself holds duplicate keys,
        # and duplicates would make update_many reject the whole batch.
        exist_key_index = sorted(set(self._get_exist_keys_index(table, keys, key_values, flag)))
        exist_set = set(exist_key_index)
        new_key_index = [i for i in range(len(key_values)) if i not in exist_set]

        self.insert_many(table=table,
                         keys=keys,
                         tags=tags,
                         tag_values=[tag_values[i] for i in new_key_index],
                         key_values=[key_values[i] for i in new_key_index],
                         flag=flag)

        self.update_many(table=table,
                         keys=keys,
                         tags=tags,
                         tag_values=[tag_values[i] for i in exist_key_index],
                         key_values=[key_values[i] for i in exist_key_index],
                         flag=flag,
                         split=split)

    def insertorupdatemany_v3(self, df, table, keys, tags, flag=False, split=80):
        """``insertorupdatemany_v2`` taking key and tag columns from one DataFrame.

        ``keys``/``tags`` may be tuples or lists; they are converted to lists
        for column selection because ``df[tuple]`` is a single-label lookup.
        """
        self.insertorupdatemany_v2(
            table=table,
            keys=keys,
            tags=tags,
            tag_values=df[list(tags)],
            key_values=df[list(keys)],
            flag=flag,
            split=split
        )

    def _get_s_format(self, data):
        """
        Args:
            data: [[featureA1, featureB1, ...], [featureA2, featureB2, ...], ...]

        Returns:
            (placeholder string, flat list of values). Single-value rows are
            not wrapped in their own parentheses.

        Example:
            [['2017-07-01', 78], ['2017-07-01', 1]] ->
                     ('((%s,%s),(%s,%s))', ['2017-07-01', 78, '2017-07-01', 1])
        """
        row_formats = []
        values = []
        for row in data:
            values.extend(row)
            placeholders = ','.join(['%s'] * len(row))
            row_formats.append('(' + placeholders + ')' if len(row) > 1 else placeholders)
        return '(' + ','.join(row_formats) + ')', values

    def delete_by_key(self, table, keys, key_values, flag=False):
        """Delete the rows whose key tuple is in ``key_values``, then commit.

        Args:
            table: table name
            keys: names of the (composite) key columns
            key_values: rows of key values (list or pd.DataFrame), or one flat row
            flag: log the SQL when truthy

        Examples:
            delete_by_key('table_test', keys=['date'], key_values=[['2017-07-01'], ['2017-07-02']], flag=False)
            delete_by_key('table_test', keys=['date'], key_values=['2017-07-01'], flag=False)
        """
        if len(key_values) == 0:
            return
        # Check for a DataFrame first: df[0] is a column lookup and raises KeyError.
        if isinstance(key_values, pd.DataFrame) or isinstance(key_values[0], (list, tuple)):
            key_values_list = self._convert_to_list(key_values)
        else:
            key_values_list = [key_values]
        sql_keys = '(' + ','.join(keys) + ')'

        contact_s, values_s = self._get_s_format(key_values_list)
        sql_del = """
            delete from %s
            where %s in %s
        """ % (table, sql_keys, contact_s)
        if flag:
            log.debug(self._sql_for_log(sql_del, tuple(values_s)))
        self.cursor.execute(sql_del, tuple(values_s))
        self.conn.commit()

    def insert_many(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
        """Insert many rows with multi-row INSERT statements and commit.

        :param table: table name
        :param keys: names of the (composite) key columns (tuple or list)
        :param tags: names of the value columns (tuple or list)
        :param tag_values: rows of tag values (list or pd.DataFrame)
        :param key_values: rows of key values (list or pd.DataFrame)
        :param flag: log the SQL when truthy
        :param split: rows per INSERT statement

        Examples: see insertorupdatemany_v2
        insert into table
        (count_date, cid, tag1, tag2)
        values ('2017-01-01', 10, 1, 'a'), ('2017-01-02', 20, 2, 'b'), ...
        """
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('insert 0 rows')
            return
        if len(tag_values) > split:
            for start in range(0, len(tag_values), split):
                finish = start + split
                self.insert_many(table, keys, tags, tag_values[start:finish],
                                 key_values[start:finish], flag, split=split)
            return
        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)

        # list(): keys/tags and the rows themselves may mix tuples and lists.
        columns = list(keys) + list(tags)
        feature_total = "(" + ",".join(columns) + ")"
        tmp_s = "(" + ",".join(["%s"] * len(columns)) + ")"
        tmp_s_concat = ",\n".join([tmp_s] * len(key_values))
        sql_insert = """
                Insert into %s
                %s
                values %s""" % (table, feature_total, tmp_s_concat)
        value_insert = []
        for key_value, tag_value in zip(key_values, tag_values):
            value_insert.extend(key_value)
            value_insert.extend(tag_value)
        if flag:
            log.debug(self._sql_for_log(sql_insert, tuple(value_insert)))
        t0 = time.time()

        self.cursor.execute(sql_insert, tuple(value_insert))
        log.debug('insert %s rows, cost: %s' % (len(key_values), time.time() - t0))
        self.conn.commit()

    def update_many(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
        """Update many rows with one CASE-based UPDATE per batch (never inserts).

        :param table: table name
        :param keys: names of the (composite) key columns
        :param tags: names of the value columns
        :param tag_values: rows of tag values (list or pd.DataFrame)
        :param key_values: rows of key values (list or pd.DataFrame)
        :param flag: log the SQL when truthy
        :param split: rows per UPDATE statement
        A batch containing a repeated key is rejected (logged, nothing updated).

        Example for keys (count_date, cid) and tags (tag1, tag2)::

            update table
             set tag1 = case
                when (count_date,cid)=('2017-01-01',10) then 1
                when (count_date,cid)=('2017-01-02',20) then 2
             end,
            tag2 = case
                when (count_date,cid)=('2017-01-01',10) then 'a'
                when (count_date,cid)=('2017-01-02',20) then 'b'
             end
             where (count_date = '2017-01-01' and cid = 10) or (count_date = '2017-01-02' and cid = 20)
        """
        if len(tag_values) > split:
            for start in range(0, len(tag_values), split):
                finish = start + split
                self.update_many(table, keys, tags, tag_values[start:finish],
                                 key_values[start:finish], flag, split=split)
            return
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('update 0 rows')
            return

        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)

        if self._check_repeat_key(key_values):
            return

        sql_keys = ','.join(keys)
        sql_key_values = ','.join(['%s'] * len(keys))
        if len(keys) > 1:
            sql_keys = '(' + sql_keys + ')'
            sql_key_values = '(' + sql_key_values + ')'
        # Identical for every row: bind key values, then the tag value.
        sql_when = 'when %s=%s then %%s ' % (sql_keys, sql_key_values)
        sql_when_concat = '\n\t'.join([sql_when] * len(tag_values))

        update_value = []
        sql_set_list = []
        for i, tag in enumerate(tags):
            for j, tag_value in enumerate(tag_values):
                update_value.extend(key_values[j])
                update_value.append(tag_value[i])
            sql_set_list.append('%s = case \n\t %s\n end' % (tag, sql_when_concat))
        for key_value in key_values:
            update_value.extend(key_value)

        row_condition = '(%s)' % ' and '.join('%s = %%s' % key for key in keys)
        sql_where = ' or '.join([row_condition] * len(key_values))

        sql = "update %s\n set %s\n where %s" % (table, ',\n'.join(sql_set_list), sql_where)
        if flag:
            log.info(self._sql_for_log(sql, tuple(update_value)))
        t0 = time.time()
        self.cursor.execute(sql, tuple(update_value))
        self.conn.commit()
        log.debug('update %s rows, cost: %s' % (len(key_values), time.time() - t0))

    def getColumn(self, table, flag=0, schema=None):
        """Return the column names of ``table`` in ordinal order.

        :param table: table name (bound as a query parameter, not interpolated)
        :param flag: truthy -> list of names; falsy -> comma-separated string
        :param schema: database containing ``table``; None means the
            connection's current database (``DATABASE()``), so same-named
            tables in other schemas do not contribute duplicate columns
        """
        if schema is None:
            schema_sql, args = "DATABASE()", (table,)
        else:
            schema_sql, args = "%s", (table, schema)
        sql = ("SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` "
               "WHERE `TABLE_NAME`=%s AND `TABLE_SCHEMA`=" + schema_sql +
               " ORDER BY ordinal_position")
        self.cursor.execute(sql, args)
        names = [row[0] for row in self.cursor.fetchall()]
        return names if flag else ','.join(names)