"""
@desc Wrappers around common MySQL database operations
@auth chenkai
@date 2020/11/19
@py_version py3.6
"""
import pymysql
import logging as log  # NOTE: this name is rebound to the project logger just below
import pandas as pd
import time
from model.log import logger

log = logger()
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)

# When truthy, the query helpers log the rendered SQL and its elapsed time.
MYSQL_DEBUG = 1


class MysqlOperation:
    """Convenience wrapper around one pymysql connection and one cursor.

    Table and column names passed to the bulk helpers are concatenated into
    the SQL text as-is (only values are bound as parameters), so they must
    come from trusted code, never from user input.
    """

    def __init__(self, host, user, passwd, db, port=3306):
        """Open the connection (utf8mb4) and create the default cursor.

        Raises:
            Exception: whatever ``pymysql.connect`` raised. The error is
                logged and re-raised so callers never receive an instance
                that has no ``conn`` / ``cursor`` attributes.
        """
        try:
            self.conn = pymysql.connect(host=host, user=user, passwd=passwd, db=db,
                                        charset='utf8mb4', port=port)
            self.cursor = self.conn.cursor()
        except Exception as e:
            log.error(e)
            raise

    def __enter__(self):
        """Allow ``with MysqlOperation(...) as db:`` usage."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection when leaving a ``with`` block."""
        self.close()
        return False

    def close(self):
        """Close the cursor and the connection; safe to call more than once."""
        cursor = getattr(self, 'cursor', None)
        conn = getattr(self, 'conn', None)
        self.cursor = None
        self.conn = None
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

    # ------------------------------------------------------------------
    # internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _render_sql(sql, args=None):
        """Return ``sql`` with ``args`` interpolated, for logging only.

        Never raises: logging must not fail a statement that has already
        been executed. When interpolation is impossible (placeholder count
        mismatch, literal ``%`` in the SQL, ...) the raw SQL and the repr of
        the arguments are returned instead.
        """
        if not args:
            return sql
        if isinstance(args, dict):
            values = args
        elif isinstance(args, (list, tuple)):
            values = tuple(args)
        else:
            values = (args,)
        try:
            return sql % values
        except (TypeError, ValueError, KeyError):
            return '%s  -- args: %r' % (sql, args)

    def _log_sql(self, sql, args, start):
        """Log the rendered SQL and elapsed time when MYSQL_DEBUG is on."""
        if MYSQL_DEBUG:
            log.info('sql: \n' + self._render_sql(sql, args))
            log.info('sql cost: %s' % (time.time() - start))

    @staticmethod
    def _insert_if_absent_sql(table, keys, tags):
        """Build ``INSERT ... SELECT ... WHERE NOT EXISTS (<key match>)``.

        Placeholder order: key values, tag values, key values again.
        """
        columns = list(keys) + list(tags)
        condition = " AND ".join(k + "=%s" for k in keys)
        # SELECT 1 instead of SELECT id: the existence probe must not depend
        # on the table having a column named `id`.
        return ("INSERT INTO " + table + " (" + ",".join(columns) + ") SELECT "
                + ",".join(["%s"] * len(columns))
                + " FROM DUAL WHERE NOT EXISTS (SELECT 1 FROM " + table
                + " WHERE " + condition + " LIMIT 1)")

    @staticmethod
    def _update_by_key_sql(table, keys, tags):
        """Build ``UPDATE ... SET <tags> WHERE <key match>``.

        Placeholder order: tag values, then key values.
        """
        assignments = ",".join(t + "=%s" for t in tags)
        condition = " AND ".join(k + "=%s" for k in keys)
        return "UPDATE " + table + " SET " + assignments + " WHERE " + condition

    def _update_row(self, table, keys, tags, tag_value, key_value):
        """Run the single-row UPDATE; return the ``(sql, args)`` executed."""
        sql = self._update_by_key_sql(table, keys, tags)
        arg = list(tag_value) + list(key_value)
        self.cursor.execute(sql, args=arg)
        return sql, arg

    def _insert_or_update_row(self, table, keys, tags, tag_value, key_value):
        """Insert the row if its key is absent, otherwise update it.

        Returns the ``(sql, args)`` of the last statement executed. Nothing
        is committed here.
        """
        key_value = list(key_value)
        tag_value = list(tag_value)
        sql = self._insert_if_absent_sql(table, keys, tags)
        arg = key_value + tag_value + key_value
        rows = self.cursor.execute(sql, args=arg)
        if rows == 0:
            # Nothing inserted -> the key already exists, so update instead.
            sql, arg = self._update_row(table, keys, tags, tag_value, key_value)
        return sql, arg

    # ------------------------------------------------------------------
    # basic query helpers
    # ------------------------------------------------------------------
    def set_dict_cursor(self):
        """Switch to a DictCursor so rows are returned as dictionaries."""
        self.cursor = self.conn.cursor(pymysql.cursors.DictCursor)

    def getData(self, sql, args=None):
        """Execute a query and return every row.

        :param sql: SQL text with ``%s`` placeholders
        :param args: parameters bound to the placeholders
        :return: whatever ``cursor.fetchall()`` returns
                 (tuple of tuples with the default cursor)
        """
        start = time.time()
        self.cursor.execute(sql, args=args)
        result = self.cursor.fetchall()
        self._log_sql(sql, args, start)
        return result

    def get_data_list(self, sql, arg=None):
        """Like :meth:`getData` but every row is converted to a list.

        :return: list[list]
        """
        return [list(row) for row in self.getData(sql, arg)]

    def execute(self, sql, args=None):
        """Execute a statement and commit immediately."""
        start = time.time()
        self.cursor.execute(query=sql, args=args)
        self.conn.commit()
        self._log_sql(sql, args, start)

    def getOne(self, sql, args=None):
        """Return the first column of the first row.

        Raises:
            IndexError: when the query returns no rows.
        """
        result = self.getData(sql, args)
        return result[0][0]

    def getOneList(self, sql, args=None):
        """Return the first column of every row as a list."""
        return [row[0] for row in self.getData(sql, args)]

    def getData_pd(self, sql, args=None):
        """Execute a query and return the rows as a pandas DataFrame.

        Column names are taken from ``cursor.description``.
        """
        start = time.time()
        self.cursor.execute(sql, args=args)
        field_names = [col[0] for col in self.cursor.description]
        rows = self.cursor.fetchall()
        df = pd.DataFrame(data=list(rows), columns=field_names)
        self._log_sql(sql, args, start)
        return df

    def getData_json(self, sql, args=None):
        """Execute a query and return the rows as ``[{column: value}, ...]``.

        :param sql: SQL text
        :param args: optional bound parameters
        """
        self.cursor.execute(sql, args)
        field_names = [col[0] for col in self.cursor.description]
        return [dict(zip(field_names, row)) for row in self.cursor.fetchall()]

    def insertData(self, sql, args=None):
        """Execute a write statement, log it (debug mode) and commit."""
        start = time.time()
        self.cursor.execute(sql, args=args)
        self._log_sql(sql, args, start)
        self.conn.commit()

    def executeWithoutCommit(self, sql, args=None):
        """Execute a statement without committing; return the affected-row count."""
        return self.cursor.execute(sql, args=args)

    def commit(self):
        """Commit the current transaction."""
        self.conn.commit()

    # ------------------------------------------------------------------
    # row-by-row upsert
    # ------------------------------------------------------------------
    def insertorupdate(self, table, keys, tags, tagvalue, flag, *args):
        """Insert one row, or update it when the key already exists; commit.

        :param table: table name
        :param keys: tuple of (composite) key column names
        :param tags: tuple of value column names
        :param tagvalue: values for ``tags``
        :param flag: truthy -> log the last statement executed
        :param args: values for ``keys``
        """
        sql, arg = self._insert_or_update_row(table, keys, tags, tagvalue, args)
        if flag:
            log.info(self._render_sql(sql, arg))
        self.conn.commit()

    def _insertorupdate(self, table, keys, tags, tag_value, flag, key_value, update=False):
        """Single-row upsert without commit.

        With ``update=True`` the insert attempt is skipped and only the
        UPDATE is issued.
        """
        if update:
            sql, arg = self._update_row(table, keys, tags, tag_value, key_value)
        else:
            sql, arg = self._insert_or_update_row(table, keys, tags, tag_value, key_value)
        if flag:
            log.info(self._render_sql(sql, arg))

    def _insert_on_duplicate(self, table, keys, tags, tag_value, flag, key_value):
        """Single-row ``INSERT ... ON DUPLICATE KEY UPDATE`` without commit.

        Requires ``keys`` to be a unique/primary key of ``table``.
        """
        name_all = list(keys) + list(tags)
        # tag values are bound twice: once for VALUES, once for the UPDATE part
        arg = list(key_value) + list(tag_value) + list(tag_value)
        sql_name = '(' + ','.join(name_all) + ')'
        sql_value = '(' + ','.join(['%s'] * len(name_all)) + ')'
        sql_update = ','.join([tag + '=%s' for tag in tags])
        sql = """
            insert into %s %s
            VALUES %s
            ON duplicate key UPDATE %s
        """ % (table, sql_name, sql_value, sql_update)
        self.cursor.execute(sql, args=arg)
        if flag:
            log.debug(self._render_sql(sql, arg))

    def insertorupdatemany(self, table, keys, tags, tag_values, key_values,
                           flag=False, unique_key=False, update=False):
        """Upsert many rows one statement at a time, then commit once.

        :param table: table name
        :param keys: tuple of (composite) key column names
        :param tags: tuple of value column names
        :param tag_values: value rows (list or pd.DataFrame)
        :param key_values: key rows (list or pd.DataFrame)
        :param flag: truthy -> log every statement
        :param unique_key: whether ``keys`` is a unique key of ``table``
                           (enables ON DUPLICATE KEY UPDATE)
        :param update: only meaningful when ``unique_key`` is false;
                       skip the insert attempt and only UPDATE
        :return: None

        Original author's note: throughput over a public network is roughly
        rows / 50 seconds.
        """
        # NaN cells become None so they are written as NULL.
        list_tag_value = self._convert_to_list(tag_values)
        list_key_value = self._convert_to_list(key_values)
        for i, tag_value in enumerate(list_tag_value):
            key_value = list_key_value[i]
            if unique_key:
                self._insert_on_duplicate(table, keys, tags, tag_value, flag, key_value)
            else:
                self._insertorupdate(table, keys, tags, tag_value, flag, key_value, update)
        self.conn.commit()

    # ------------------------------------------------------------------
    # batch helpers
    # ------------------------------------------------------------------
    def _check_repeat_key(self, key_list):
        """Return True (and log up to 10 offenders) when a key row repeats.

        Uses hashing rather than sorting so that keys containing ``None``
        (produced from NaN by :meth:`_convert_to_list`) do not raise.
        """
        seen = set()
        repeat_key = set()
        for key in map(tuple, key_list):
            if key in seen:
                repeat_key.add(key)
                if len(repeat_key) >= 10:
                    break
            else:
                seen.add(key)
        if not repeat_key:
            return False
        log.error('Reject repeated keys')
        log.error('repeat_key: %s' % repeat_key)
        return True

    def _convert_to_list(self, data):
        """Normalize rows to a list.

        A DataFrame becomes ``list[list]`` with NaN replaced by ``None``;
        any other iterable is simply passed through ``list()``.
        """
        if isinstance(data, pd.DataFrame):
            # NaN is the only value for which x != x holds.
            return [[None if x != x else x for x in data.iloc[i, :]]
                    for i in range(len(data))]
        return list(data)

    def _get_exist_keys_index(self, table, keys, key_values, flag=False):
        """Return the positions in ``key_values`` whose key exists in ``table``.

        One SELECT is issued whose CASE expression maps every matching row
        to the index of the key that matched. An index can appear more than
        once when several table rows share the same key.
        """
        list_sql_when = []
        params = []
        for i, key_value in enumerate(key_values):
            list_sql_when.append(
                """when (%s)=(%s) then %s""" % (
                    ','.join(keys), ','.join(['%s'] * len(key_value)), i))
            params.extend(key_value)
        # One "(k1 = %s and k2 = %s)" group per key row, OR-ed together.
        sql_condition = """(%s)""" % ' and '.join('%s = %%s' % k for k in keys)
        list_sql_condition = []
        for key_value in key_values:
            list_sql_condition.append(sql_condition)
            params.extend(key_value)
        sql_where = ' or '.join(list_sql_condition)
        sql_case = '\n'.join(list_sql_when)
        sql = """
            select
            case
            %s
            end
            from %s
            where %s
        """ % (sql_case, table, sql_where)
        if flag:
            log.info(self._render_sql(sql, params))
        self.cursor.execute(sql, tuple(params))
        return [row[0] for row in self.cursor.fetchall()]

    def insertorupdatemany_v2(self, table, keys, tags, tag_values, key_values,
                              flag=False, split=80):
        """Batch upsert: insert rows whose key is new, update the others.

        :param table: table name
        :param keys: tuple of (composite) key column names
        :param tags: tuple of value column names
        :param tag_values: value rows (list, tuple or pd.DataFrame)
        :param key_values: key rows (list, tuple or pd.DataFrame)
        :param flag: truthy -> log the generated statements
        :param split: maximum rows per batch
        :return: None
        :raises SystemExit: when ``tag_values`` has an unsupported type
        :raises AssertionError: when ``key_values`` contains repeated keys

        Original author's note: cost over a public network is roughly
        rows^2 / 50000 seconds, hence the batching by ``split``.
        """
        if not isinstance(tag_values, (tuple, list, pd.DataFrame)):
            log.error('Type Error')
            # Same effect as the historical exit(-1) without depending on the
            # `site`-provided builtin.
            raise SystemExit(-1)
        if len(tag_values) > split:
            for start in range(0, len(tag_values), split):
                finish = start + split
                self.insertorupdatemany_v2(table, keys, tags,
                                           tag_values[start:finish],
                                           key_values[start:finish],
                                           flag, split=split)
            return
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('insert or update 0 rows')
            return
        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)
        # Explicit raise: a bare `assert` disappears under `python -O`.
        if self._check_repeat_key(key_values):
            raise AssertionError('repeated keys in key_values')
        # De-duplicate: several table rows sharing one key would otherwise
        # make update_many reject the whole batch as "repeated keys".
        exist_key_index = sorted({
            i for i in self._get_exist_keys_index(table, keys, key_values, flag)
            if i is not None
        })
        exist_set = set(exist_key_index)
        new_key_index = [i for i in range(len(key_values)) if i not in exist_set]
        update_keys = [key_values[i] for i in exist_key_index]
        update_tags = [tag_values[i] for i in exist_key_index]
        insert_keys = [key_values[i] for i in new_key_index]
        insert_tags = [tag_values[i] for i in new_key_index]
        self.insert_many(table=table, keys=keys, tags=tags,
                         tag_values=insert_tags, key_values=insert_keys, flag=flag)
        self.update_many(table=table, keys=keys, tags=tags,
                         tag_values=update_tags, key_values=update_keys,
                         flag=flag, split=split)

    def insertorupdatemany_v3(self, df, table, keys, tags, flag=False, split=80):
        """Batch upsert straight from a DataFrame.

        ``keys`` and ``tags`` are used to select columns of ``df`` and are
        then forwarded to :meth:`insertorupdatemany_v2`.
        """
        self.insertorupdatemany_v2(
            table=table,
            keys=keys,
            tags=tags,
            tag_values=df[tags],
            key_values=df[keys],
            flag=flag,
            split=split
        )

    def _get_s_format(self, data):
        """Build the nested ``%s`` placeholder text for an IN clause.

        Args:
            data: [[featureA1, featureB1, ...], [featureA2, featureB2, ...], ...]
        Returns:
            (placeholder text, flattened values)
        Example:
            [['2017-07-01', 78], ['2017-07-01', 1]]
            -> ('((%s,%s),(%s,%s))', ['2017-07-01', 78, '2017-07-01', 1])
        """
        groups = []
        values = []
        for row in data:
            group = ','.join(len(row) * ['%s'])
            values.extend(row)
            if len(row) > 1:
                # Single-column rows stay unparenthesized.
                group = '(' + group + ')'
            groups.append(group)
        return '(' + ','.join(groups) + ')', values

    def delete_by_key(self, table, keys, key_values, flag=False):
        """Delete every row whose key is in ``key_values``, then commit.

        Args:
            table: table name
            keys: tuple of (composite) key column names
            key_values: key rows (list of rows, one flat row, or pd.DataFrame)
            flag: truthy -> log the statement
        Examples:
            delete_by_key('table_test', keys=['date'], key_values=[['2017-07-01'], ['2017-07-02']], flag=False)
            delete_by_key('table_test', keys=['date'], key_values=['2017-07-01'], flag=False)
        """
        if len(key_values) == 0:
            return
        # The DataFrame test must come first: key_values[0] on a DataFrame is
        # a column lookup and raises KeyError.
        if isinstance(key_values, pd.DataFrame) or isinstance(key_values[0], (list, tuple)):
            key_values_list = self._convert_to_list(key_values)
        else:
            key_values_list = [key_values]
        sql_keys = '(' + ','.join(keys) + ')'
        contact_s, values_s = self._get_s_format(key_values_list)
        sql_del = """
            delete from %s
            where %s in %s
        """ % (table, sql_keys, contact_s)
        if flag:
            log.debug(self._render_sql(sql_del, values_s))
        self.cursor.execute(sql_del, tuple(values_s))
        self.conn.commit()

    def insert_many(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
        """Insert many rows with multi-row INSERT statements, committing per batch.

        :param table: table name
        :param keys: tuple of (composite) key column names
        :param tags: tuple of value column names
        :param tag_values: value rows (list or pd.DataFrame)
        :param key_values: key rows (list or pd.DataFrame)
        :param flag: truthy -> log the statement
        :param split: maximum rows per statement
        :return: None
        Examples: see insertorupdatemany_v2
            insert into table (count_date, cid, tag1, tag2)
            values ('2017-01-01', 10, 1, 'a'), ('2017-01-02', 20, 2, 'b'), ...
        """
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('insert 0 rows')
            return
        if len(tag_values) > split:
            for start in range(0, len(tag_values), split):
                finish = start + split
                self.insert_many(table, keys, tags, tag_values[start:finish],
                                 key_values[start:finish], flag, split=split)
            return
        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)
        # list(...) so that a tuple of keys can be combined with a list of tags.
        columns = list(keys) + list(tags)
        feature_total = "(" + ",".join(columns) + ")"
        tmp_s = "(" + ",".join(["%s"] * len(columns)) + ")"
        tmp_s_concat = ",\n".join([tmp_s] * len(key_values))
        sql_insert = """
            Insert into %s %s
            values %s""" % (table, feature_total, tmp_s_concat)
        value_insert = []
        for key_row, tag_row in zip(key_values, tag_values):
            # extend twice instead of key_row + tag_row: rows may be tuples.
            value_insert.extend(key_row)
            value_insert.extend(tag_row)
        if flag:
            log.debug(self._render_sql(sql_insert, value_insert))
        t0 = time.time()
        self.cursor.execute(sql_insert, tuple(value_insert))
        log.debug('insert %s rows, cost: %s' % (len(key_values), time.time() - t0))
        self.conn.commit()

    def update_many(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
        """Update many rows with one CASE-based UPDATE per batch (never inserts).

        :param table: table name
        :param keys: tuple of (composite) key column names
        :param tags: tuple of value column names
        :param tag_values: value rows (list or pd.DataFrame)
        :param key_values: key rows (list or pd.DataFrame)
        :param flag: truthy -> log the statement
        :param split: maximum rows per statement
        :return: None; a batch containing repeated keys is logged and skipped
        Examples: see insertorupdatemany_v2
            # single-row form
            update table set tag1=1, tag2='a' where (count_date, cid) =('2017-01-01', 10)
            # combined form generated here
            update table
            set tag1 = case when (count_date, cid)=('2017-01-01', 10) then 1
                            when (count_date, cid)=('2017-01-02', 20) then 2
                            ... end,
                tag2 = case when (count_date, cid)=('2017-01-01', 10) then 'a'
                            when (count_date, cid)=('2017-01-02', 20) then 'b'
                            ... end
            where (count_date = '2017-01-01' and cid = 10)
               or (count_date = '2017-01-02' and cid = 20) or ...
        """
        if len(tag_values) > split:
            for start in range(0, len(tag_values), split):
                finish = start + split
                self.update_many(table, keys, tags, tag_values[start:finish],
                                 key_values[start:finish], flag, split=split)
            return
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('update 0 rows')
            return
        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)
        if self._check_repeat_key(key_values):
            return
        sql_keys = ','.join(keys)
        sql_key_values = ','.join(['%s'] * len(keys))
        if len(keys) > 1:
            # Composite keys are compared as row constructors: (a,b)=(%s,%s)
            sql_keys = '(' + sql_keys + ')'
            sql_key_values = '(' + sql_key_values + ')'
        # The WHEN fragment is the same for every row; only bound values differ.
        sql_when = """when %s=%s then %s """ % (sql_keys, sql_key_values, '%s')
        update_value = []
        sql_set_list = []
        for i, tag in enumerate(tags):
            for j, tag_row in enumerate(tag_values):
                update_value.extend(key_values[j])
                update_value.append(tag_row[i])
            sql_when_concat = '\n\t'.join([sql_when] * len(tag_values))
            sql_set_list.append("""%s = case \n\t %s\n end""" % (tag, sql_when_concat))
        # Parameters of the WHERE clause follow those of all CASE expressions.
        for key_row in key_values:
            update_value.extend(key_row)
        sql_set_concat = ',\n'.join(sql_set_list)
        sql_condition = """(%s)""" % ' and '.join('%s = %%s' % k for k in keys)
        sql_where = ' or '.join([sql_condition] * len(key_values))
        sql = """update %s\n set %s\n where %s""" % (table, sql_set_concat, sql_where)
        if flag:
            log.info(self._render_sql(sql, update_value))
        t0 = time.time()
        self.cursor.execute(sql, tuple(update_value))
        self.conn.commit()
        log.debug('update %s rows, cost: %s' % (len(key_values), time.time() - t0))

    def getColumn(self, table, flag=0):
        """Return the column names of ``table`` in ordinal order.

        :param table: table name (bound as a query parameter)
        :param flag: truthy -> return a list; falsy -> a comma-joined string

        NOTE(review): the lookup does not filter on TABLE_SCHEMA, so a table
        with the same name in another schema contributes its columns too.
        """
        sql = "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` " \
              "WHERE `TABLE_NAME`=%s ORDER BY ordinal_position"
        self.cursor.execute(sql, (table,))
        columns = [row[0] for row in self.cursor.fetchall()]
        if flag:
            return columns
        return ','.join(columns)