123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652 |
- """
- @desc Wrapper methods for common MySQL database operations
- @auth chenkai
- @date 2020/11/19
- @py_version py3.6
- """
- import pymysql
- import logging as log
- import pandas as pd
- import time
- from model.log import logger
# Project logger used by every method below; this rebinds the name ``log``
# that the import section first bound to the stdlib ``logging`` module.
log = logger()

# Show every column and allow wide rows when DataFrames are printed.
pd.set_option('display.max_columns', None, 'display.width', 1000)

# Truthy value makes getData / getData_pd log the SQL text and its run time.
MYSQL_DEBUG = 1
class MysqlOperation:
    """Convenience wrapper around one ``pymysql`` connection and cursor.

    Table and column names handed to the helper methods are concatenated
    straight into the SQL text (identifiers cannot be sent as query
    parameters), so they must come from trusted code.  Row values are passed
    to the driver as query parameters.
    """

    def __init__(self, host, user, passwd, db, port=3306):
        """Open the connection and create the default cursor.

        A failed connect is logged and NOT raised (kept that way for backward
        compatibility); ``self.conn`` and ``self.cursor`` are then ``None``.

        :param host: server host name
        :param user: login user
        :param passwd: login password
        :param db: default database
        :param port: server port
        """
        # Pre-set so the attributes exist even when connecting fails.
        self.conn = None
        self.cursor = None
        try:
            self.conn = pymysql.connect(host=host,
                                        user=user,
                                        passwd=passwd,
                                        db=db,
                                        charset='utf8mb4',
                                        port=port)
            self.cursor = self.conn.cursor()
        except Exception as e:
            log.error(e)

    def __enter__(self):
        """Allow ``with MysqlOperation(...) as db:`` usage."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the cursor and connection when the ``with`` block ends."""
        self.close()

    def close(self):
        """Close the cursor and the connection if they are open."""
        if self.cursor is not None:
            self.cursor.close()
            self.cursor = None
        if self.conn is not None:
            self.conn.close()
            self.conn = None

    def set_dict_cursor(self):
        """Replace the cursor with one that returns each row as a dict."""
        self.cursor = self.conn.cursor(pymysql.cursors.DictCursor)

    @staticmethod
    def _render_sql(sql, args=None):
        """Return *sql* with *args* interpolated, for log output only.

        Never raises: when the naive ``%`` interpolation is impossible (scalar
        args, literal ``%`` in the statement, missing dict key) the statement
        and the raw args are returned side by side instead.
        """
        if not args:
            return sql
        try:
            return sql % (args if isinstance(args, dict) else tuple(args))
        except (TypeError, ValueError, KeyError):
            return '%s\nargs: %r' % (sql, args)

    @staticmethod
    def _insert_if_absent_sql(table, keys, tags):
        """Build ``INSERT ... SELECT ... WHERE NOT EXISTS`` for one row.

        Parameter order: key values, tag values, key values again.  The
        sub-select reads a column named ``id`` from *table*.
        """
        columns = ','.join(list(keys) + list(tags))
        placeholders = ','.join(['%s'] * (len(keys) + len(tags)))
        where = ' AND '.join(key + '=%s' for key in keys)
        return ('INSERT INTO ' + table + ' (' + columns + ') SELECT '
                + placeholders
                + ' FROM DUAL WHERE NOT EXISTS (SELECT id FROM ' + table
                + ' WHERE ' + where + ' LIMIT 1)')

    @staticmethod
    def _update_by_keys_sql(table, keys, tags):
        """Build ``UPDATE table SET tag=%s,... WHERE key=%s AND ...``.

        Parameter order: tag values, then key values.
        """
        assignments = ','.join(tag + '=%s' for tag in tags)
        where = ' AND '.join(key + '=%s' for key in keys)
        return 'UPDATE ' + table + ' SET ' + assignments + ' WHERE ' + where

    def getData(self, sql, args=None):
        """Run a query and return every row.

        :param sql: statement, optionally with ``%s`` placeholders
        :param args: parameters for the placeholders
        :return: whatever ``cursor.fetchall()`` returns
        """
        start = time.time()
        self.cursor.execute(sql, args=args)
        result = self.cursor.fetchall()
        if MYSQL_DEBUG:
            log.info('sql: \n' + self._render_sql(sql, args))
            log.info('sql cost: %s' % (time.time() - start))
        return result

    def get_data_list(self, sql, arg=None):
        """Run a query and return the rows as a list of lists.

        :param sql: statement, optionally with ``%s`` placeholders
        :param arg: parameters for the placeholders
        :return: list[list]
        """
        return [list(row) for row in self.getData(sql, arg)]

    def getDataOneList(self, sql, args=None):
        """Run a query and return the first column of every row as a list.

        :param sql: statement, optionally with ``%s`` placeholders
        :param args: parameters for the placeholders
        """
        return [row[0] for row in self.getData(sql, args)]

    def execute(self, sql, data=None):
        """Execute one statement, commit, and print the affected row count.

        :param sql: statement, optionally with ``%s`` placeholders
        :param data: parameters; a falsy value runs the statement without any
        :return: the value returned by ``cursor.execute``
        """
        if data:
            affected = self.cursor.execute(sql, data)
        else:
            affected = self.cursor.execute(sql)
        self.conn.commit()
        print(f"affect rows :{affected}")
        return affected

    def executeMany(self, sql, data):
        """Execute one statement for every parameter row, commit, and print.

        :param sql: statement with ``%s`` placeholders
        :param data: sequence of parameter rows
        :return: the value returned by ``cursor.executemany``
        """
        affected = self.cursor.executemany(sql, data)
        self.conn.commit()
        print(f"\033[1;36maffect rows :{affected} \033[0m")
        return affected

    def getOne(self, sql, args=None):
        """Return the first column of the first row, or ``None`` if no rows.

        :param sql: statement, optionally with ``%s`` placeholders
        :param args: parameters for the placeholders
        """
        result = self.getData(sql, args)
        if not result:
            return None
        return result[0][0]

    def getData_pd(self, sql, args=None):
        """Run a query and return the result as a ``pandas.DataFrame``.

        Column names are taken from ``cursor.description``.

        :param sql: statement, optionally with ``%s`` placeholders
        :param args: parameters for the placeholders
        """
        start = time.time()
        self.cursor.execute(sql, args=args)
        # description is None for statements that produce no result set.
        field_names = [column[0] for column in (self.cursor.description or ())]
        rows = self.cursor.fetchall()
        df = pd.DataFrame(data=list(rows), columns=field_names)
        if MYSQL_DEBUG:
            log.info('sql: \n' + self._render_sql(sql, args))
            log.info('sql cost: %s' % (time.time() - start))
        return df

    def executeWithoutCommit(self, sql, args=None):
        """Execute one statement without committing; return the cursor result."""
        return self.cursor.execute(sql, args=args)

    def commit(self):
        """Commit the current transaction."""
        self.conn.commit()

    def insertorupdate(self, table, keys, tags, tagvalue, flag, *args):
        """Insert one row if its key is absent, otherwise update it; commit.

        :param table: table name
        :param keys: names of the (composite) key columns
        :param tags: names of the value columns
        :param tagvalue: values for ``tags``
        :param flag: when truthy, log the last statement that was executed
        :param args: values for ``keys``
        """
        self._insertorupdate(table, keys, tags, tagvalue, flag, args)
        self.conn.commit()

    def _insertorupdate(self, table, keys, tags, tag_value, flag, key_value, update=False):
        """Insert-if-absent then update (or update only) one row, no commit.

        :param update: when true, skip the insert attempt and only update
        """
        key_value = list(key_value)
        tag_value = list(tag_value)
        inserted = 0
        if not update:
            sql = self._insert_if_absent_sql(table, keys, tags)
            arg = key_value + tag_value + key_value
            inserted = self.cursor.execute(sql, args=arg)
        if inserted == 0:
            # Nothing was inserted (key exists, or update-only mode).
            sql = self._update_by_keys_sql(table, keys, tags)
            arg = tag_value + key_value
            self.cursor.execute(sql, args=arg)
        if flag:
            log.info(self._render_sql(sql, arg))

    def _insert_on_duplicate(self, table, keys, tags, tag_value, flag, key_value):
        """Upsert one row with ``INSERT ... ON DUPLICATE KEY UPDATE``, no commit."""
        columns = list(keys) + list(tags)
        tag_value = list(tag_value)
        # The tag values are sent twice: once for VALUES, once for UPDATE.
        arg = list(key_value) + tag_value + tag_value
        sql_name = '(' + ','.join(columns) + ')'
        sql_value = '(' + ','.join(['%s'] * len(columns)) + ')'
        sql_update = ','.join(tag + '=%s' for tag in tags)
        sql = """
            insert into %s
            %s
            VALUES %s
            ON duplicate key UPDATE %s
        """ % (table, sql_name, sql_value, sql_update)
        self.cursor.execute(sql, args=arg)
        if flag:
            log.debug(self._render_sql(sql, arg))

    def insertorupdatemany(self, table, keys, tags, tag_values, key_values, flag=False, unique_key=False, update=False):
        """Insert-or-update rows one statement at a time, then commit once.

        :param table: table name
        :param keys: names of the (composite) key columns
        :param tags: names of the value columns
        :param tag_values: value rows (list or pd.DataFrame)
        :param key_values: key rows (list or pd.DataFrame)
        :param flag: when truthy, log each statement
        :param unique_key: true when ``keys`` is a unique key of the table, so
            ``ON DUPLICATE KEY UPDATE`` can be used
        :param update: when true (and ``unique_key`` is false) only update

        Original author's note: over an external network this takes roughly
        rows / 50 (unit not stated); intended for more than 1000 rows.
        """
        # Same conversion as the other bulk methods, so NaN is sent as NULL.
        list_tag_value = self._convert_to_list(tag_values)
        list_key_value = self._convert_to_list(key_values)
        for index, tag_value in enumerate(list_tag_value):
            key_value = list_key_value[index]
            if unique_key:
                self._insert_on_duplicate(table, keys, tags, tag_value, flag, key_value)
            else:
                self._insertorupdate(table, keys, tags, tag_value, flag, key_value, update)
        self.conn.commit()

    def _check_repeat_key(self, key_list):
        """Return True (and log up to 10 offenders) if any key row repeats.

        Duplicates are found with a seen-set rather than by sorting, so key
        rows that mix ``None`` with other values do not raise ``TypeError``.
        """
        as_tuples = [tuple(row) for row in key_list]
        if len(as_tuples) == len(set(as_tuples)):
            return False
        seen = set()
        repeat_key = set()
        for key in as_tuples:
            if key in seen:
                repeat_key.add(key)
                if len(repeat_key) >= 10:
                    break
            else:
                seen.add(key)
        log.error('Reject repeated keys')
        log.error('repeat_key: %s' % repeat_key)
        return True

    def _convert_to_list(self, data):
        """Return *data* as a list; DataFrame rows become lists, NaN -> None."""
        if isinstance(data, pd.DataFrame):
            # NaN is the only value that differs from itself.
            return [[None if value != value else value
                     for value in data.iloc[row_number, :]]
                    for row_number in range(len(data))]
        return list(data)

    def _get_exist_keys_index(self, table, keys, key_values, flag=False):
        """Return the positions in *key_values* of keys already in *table*.

        One SELECT with a CASE expression maps every matching table row back
        to the index of the key row it matched; one index is returned per
        matching table row.

        :return: list of indexes as returned by the database
        """
        key_expr = ','.join(keys)
        when_clauses = []
        params = []
        for index, row in enumerate(key_values):
            placeholders = ','.join(['%s'] * len(row))
            when_clauses.append('when (%s)=(%s) then %s' % (key_expr, placeholders, index))
            params.extend(row)
        row_condition = '(%s)' % ' and '.join('%s = %%s' % key for key in keys)
        conditions = []
        for row in key_values:
            conditions.append(row_condition)
            params.extend(row)
        sql = """
            select
                case
                %s
                end
            from %s
            where %s
        """ % ('\n'.join(when_clauses), table, ' or '.join(conditions))
        if flag:
            log.info(self._render_sql(sql, params))
        self.cursor.execute(sql, tuple(params))
        return [row[0] for row in self.cursor.fetchall()]

    def insertorupdatemany_v2(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
        """Bulk upsert: insert rows whose key is absent, update the others.

        :param table: table name
        :param keys: names of the (composite) key columns
        :param tags: names of the value columns
        :param tag_values: value rows (list, tuple or pd.DataFrame)
        :param key_values: key rows (list, tuple or pd.DataFrame)
        :param flag: when truthy, log the generated statements
        :param split: maximum number of rows handled per batch
        :raises TypeError: if ``tag_values`` is not a list, tuple or DataFrame
        :raises AssertionError: if ``key_values`` contains repeated keys

        Original author's note: over an external network this takes roughly
        rows^2 / 50000 (unit not stated); rows are processed ``split`` at a
        time.
        """
        if not isinstance(tag_values, (tuple, list, pd.DataFrame)):
            log.error('Type Error')
            # Library code must not terminate the interpreter.
            raise TypeError('tag_values must be a tuple, list or pandas.DataFrame')
        if len(tag_values) > split:
            for start in range(0, len(tag_values), split):
                finish = start + split
                self.insertorupdatemany_v2(table, keys, tags, tag_values[start:finish],
                                           key_values[start:finish], flag, split=split)
            return
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('insert or update 0 rows')
            return
        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)
        if self._check_repeat_key(key_values):
            # Explicit raise: a plain ``assert`` is stripped under ``python -O``.
            raise AssertionError('key_values contains repeated keys')
        found = self._get_exist_keys_index(table, keys, key_values, flag)
        # De-duplicate: several table rows may match one key, and a repeated
        # key would make update_many reject the whole batch.
        exist_key_index = sorted({index for index in found if index is not None})
        existing = set(exist_key_index)
        new_key_index = [index for index in range(len(key_values)) if index not in existing]
        update_keys = [key_values[index] for index in exist_key_index]
        update_tags = [tag_values[index] for index in exist_key_index]
        insert_keys = [key_values[index] for index in new_key_index]
        insert_tags = [tag_values[index] for index in new_key_index]
        self.insert_many(table=table,
                         keys=keys,
                         tags=tags,
                         tag_values=insert_tags,
                         key_values=insert_keys,
                         flag=flag)
        self.update_many(table=table,
                         keys=keys,
                         tags=tags,
                         tag_values=update_tags,
                         key_values=update_keys,
                         flag=flag,
                         split=split)

    def insertorupdatemany_v3(self, df, table, keys, tags, flag=False, split=80):
        """Bulk upsert the ``keys`` and ``tags`` columns of *df* into *table*."""
        # list(...) so that tuples of column names select several columns.
        self.insertorupdatemany_v2(
            table=table,
            keys=keys,
            tags=tags,
            tag_values=df[list(tags)],
            key_values=df[list(keys)],
            flag=flag,
            split=split
        )

    def dfsave2mysql(self, df, table, key, tag):
        """Bulk upsert the ``key`` and ``tag`` columns of *df* into *table*."""
        self.insertorupdatemany_v2(
            table=table,
            keys=key,
            tags=tag,
            tag_values=df[list(tag)],
            key_values=df[list(key)]
        )

    def _get_s_format(self, data):
        """Build the ``%s`` placeholder text for an ``IN (...)`` list.

        Args:
            data: [[featureA1, featureB1, ...], [featureA2, featureB2, ...], ...]
        Returns:
            (placeholder text, flat list of values); rows with more than one
            element are wrapped in their own parentheses.
        Example:
            [['2017-07-01', 78], ['2017-07-01', 1]] ->
            ('((%s,%s),(%s,%s))', ['2017-07-01', 78, '2017-07-01', 1])
        """
        row_formats = []
        values = []
        for row in data:
            row_format = ','.join(['%s'] * len(row))
            if len(row) > 1:
                row_format = '(' + row_format + ')'
            row_formats.append(row_format)
            values.extend(row)
        return '(' + ','.join(row_formats) + ')', values

    def delete_by_key(self, table, keys, key_values, flag=False):
        """Delete the rows whose key is in *key_values*, then commit.

        Args:
            table: table name
            keys: names of the (composite) key columns
            key_values: key rows (list or pd.DataFrame), or one flat key row
            flag: when truthy, log the statement
        Examples:
            delete_by_key('table_test', keys=['date'], key_values=[['2017-07-01'], ['2017-07-02']], flag=False)
            delete_by_key('table_test', keys=['date'], key_values=['2017-07-01'], flag=False)
        """
        if len(key_values) == 0:
            return
        # Test for DataFrame first: frame[0] is a column lookup, not a row.
        if isinstance(key_values, pd.DataFrame) or isinstance(key_values[0], (list, tuple)):
            key_values_list = self._convert_to_list(key_values)
        else:
            key_values_list = [key_values]
        sql_keys = '(' + ','.join(keys) + ')'
        contact_s, values_s = self._get_s_format(key_values_list)
        sql_del = """
            delete from %s
            where %s in %s
        """ % (table, sql_keys, contact_s)
        if flag:
            log.debug(self._render_sql(sql_del, values_s))
        self.cursor.execute(sql_del, tuple(values_s))
        self.conn.commit()

    def insert_many(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
        """Insert rows with one multi-row INSERT per batch, committing each.

        :param table: table name
        :param keys: names of the (composite) key columns
        :param tags: names of the value columns
        :param tag_values: value rows (list or pd.DataFrame)
        :param key_values: key rows (list or pd.DataFrame)
        :param flag: when truthy, log the statement
        :param split: maximum number of rows per statement

        Example statement:
            insert into table
            (count_date, cid, tag1, tag2)
            values ('2017-01-01', 10, 1, 'a'), ('2017-01-02', 20, 2, 'b'), ...
        """
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('insert 0 rows')
            return
        if len(tag_values) > split:
            for start in range(0, len(tag_values), split):
                finish = start + split
                self.insert_many(table, keys, tags, tag_values[start:finish],
                                 key_values[start:finish], flag, split=split)
            return
        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)
        # list(...) so tuples and lists of names/values can be mixed freely.
        columns = list(keys) + list(tags)
        feature_total = '(' + ','.join(columns) + ')'
        row_format = '(' + ','.join(['%s'] * len(columns)) + ')'
        tmp_s_concat = ',\n'.join([row_format] * len(key_values))
        sql_insert = """
            Insert into %s
            %s
            values %s""" % (table, feature_total, tmp_s_concat)
        value_insert = []
        for key_row, tag_row in zip(key_values, tag_values):
            value_insert.extend(key_row)
            value_insert.extend(tag_row)
        if flag:
            log.debug(self._render_sql(sql_insert, value_insert))
        t0 = time.time()
        self.cursor.execute(sql_insert, tuple(value_insert))
        log.debug('insert %s rows, cost: %s' % (len(key_values), round(time.time() - t0, 2)))
        self.conn.commit()

    def update_many(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
        """Update many rows with one CASE-based UPDATE per batch; commit each.

        Rows whose key is not in the table are not inserted.  A batch that
        contains repeated keys is logged and skipped.

        :param table: table name
        :param keys: names of the (composite) key columns
        :param tags: names of the value columns
        :param tag_values: value rows (list or pd.DataFrame)
        :param key_values: key rows (list or pd.DataFrame)
        :param flag: when truthy, log the statement
        :param split: maximum number of rows per statement

        Shape of the generated statement:
            update table
            set tag1 = case
                    when (count_date, cid)=('2017-01-01', 10) then 1
                    when (count_date, cid)=('2017-01-02', 20) then 2
                end,
                tag2 = case
                    when (count_date, cid)=('2017-01-01', 10) then 'a'
                    when (count_date, cid)=('2017-01-02', 20) then 'b'
                end
            where (count_date = '2017-01-01' and cid = 10)
               or (count_date = '2017-01-02' and cid = 20) ...
        """
        if len(tag_values) > split:
            for start in range(0, len(tag_values), split):
                finish = start + split
                self.update_many(table, keys, tags, tag_values[start:finish],
                                 key_values[start:finish], flag, split=split)
            return
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('update 0 rows')
            return
        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)
        if self._check_repeat_key(key_values):
            return
        sql_keys = ','.join(keys)
        sql_key_values = ','.join(['%s'] * len(keys))
        if len(keys) > 1:
            # Composite keys are compared as row constructors.
            sql_keys = '(' + sql_keys + ')'
            sql_key_values = '(' + sql_key_values + ')'
        when_clause = 'when %s=%s then %s ' % (sql_keys, sql_key_values, '%s')
        update_value = []
        sql_set_list = []
        for tag_index, tag in enumerate(tags):
            # Parameters follow the statement text: key row, then tag value.
            for row_index, tag_row in enumerate(tag_values):
                update_value.extend(key_values[row_index])
                update_value.append(tag_row[tag_index])
            sql_when_concat = '\n\t'.join([when_clause] * len(tag_values))
            sql_set_list.append('%s = case \n\t %s\n end' % (tag, sql_when_concat))
        for key_row in key_values:
            update_value.extend(key_row)
        sql_set_concat = ',\n'.join(sql_set_list)
        row_condition = '(%s)' % ' and '.join('%s = %%s' % key for key in keys)
        sql_where = ' or '.join([row_condition] * len(key_values))
        sql = 'update %s\n set %s\n where %s' % (table, sql_set_concat, sql_where)
        if flag:
            log.info(self._render_sql(sql, update_value))
        t0 = time.time()
        self.cursor.execute(sql, tuple(update_value))
        self.conn.commit()
        log.debug('update %s rows, cost: %s' % (len(key_values), round(time.time() - t0, 2)))

    def getColumn(self, table, flag=0):
        """Return the column names of *table* in ordinal order.

        The lookup matches on table name only (no schema filter).

        :param table: table name, sent as a query parameter
        :param flag: truthy -> list of names; falsy -> comma-joined string
        """
        sql = ("SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` "
               "WHERE `TABLE_NAME`=%s ORDER BY ordinal_position")
        self.cursor.execute(sql, (table,))
        names = [row[0] for row in self.cursor.fetchall()]
        if flag:
            return names
        return ','.join(names)
|