DataBaseOperation.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. """
  2. @desc 数据库操作方法封装
  3. @auth chenkai
  4. @date 2020/11/19
  5. @py_version py3.6
  6. """
  7. import pymysql
  8. import logging as log
  9. import pandas as pd
  10. import time
  11. from model.log import logger
  # Rebind `log` to the object returned by the project's logger() factory.
  # This replaces the stdlib `logging` module that `import logging as log`
  # bound to the same name above; every `log.*` call below uses this object.
  12. log = logger()
  # pandas display options: do not truncate the column list and allow wide
  # (1000 character) lines when DataFrames are printed.
  13. pd.set_option('display.max_columns', None)
  14. pd.set_option('display.width', 1000)
  # When truthy, getData / execute / getData_pd / insertData log the SQL text
  # and the elapsed time of every statement at info level.
  15. MYSQL_DEBUG = 1
  16. class MysqlOperation:
  17. def __init__(self, host, user, passwd, db, port=3306):
  18. try:
  19. self.conn = pymysql.connect(host=host,
  20. user=user,
  21. passwd=passwd,
  22. db=db,
  23. charset='utf8mb4',
  24. port=port)
  25. self.cursor = self.conn.cursor()
  26. except Exception as e:
  27. log.info(e)
  28. def set_dict_cursor(self):
  29. """
  30. 设置字典形式取数据
  31. """
  32. self.cursor = self.conn.cursor(pymysql.cursors.DictCursor)
  33. def getData(self, sql, args=None):
  34. """
  35. :param sql:
  36. :param args:
  37. :return: tuple(tuple)
  38. """
  39. start = time.time()
  40. self.cursor.execute(sql, args=args)
  41. result = self.cursor.fetchall()
  42. if MYSQL_DEBUG:
  43. sql_str = sql % tuple(args) if args else sql
  44. log.info('sql: \n' + sql_str)
  45. log.info('sql cost: %s' % (time.time() - start))
  46. return result
  47. def get_data_list(self,sql,arg=None):
  48. """
  49. :param sql:
  50. :param arg:
  51. :return: list[list]
  52. """
  53. data=self.getData(sql,arg)
  54. li=[]
  55. for i in data:
  56. li.append(list(i))
  57. return li
  58. def get_data_dict(self,sql):
  59. data =self.getData(sql)
  60. di={}
  61. for i in data:
  62. di[i[0]]=i[1]
  63. return di
  64. def execute(self, sql,args=None):
  65. start = time.time()
  66. self.cursor.execute(query=sql,args=args)
  67. self.conn.commit()
  68. if MYSQL_DEBUG:
  69. log.info('sql: \n' + sql)
  70. log.info('sql cost: %s' % (time.time() - start))
  71. def getOne(self,sql, args=None):
  72. result = self.getData(sql, args)
  73. return result[0][0]
  74. def getOneList(self, sql, args=None):
  75. result = self.getData(sql, args)
  76. return [x[0] for x in result]
  77. def getData_pd(self, sql, args=None):
  78. start = time.time()
  79. # if args:
  80. # log.debug(sql % tuple(args))
  81. # else:
  82. # log.debug(sql)
  83. self.cursor.execute(sql, args=args)
  84. num_fields = len(self.cursor.description)
  85. field_names = [i[0] for i in self.cursor.description]
  86. df = self.cursor.fetchall()
  87. df = pd.DataFrame(data=list(df), columns=field_names)
  88. if MYSQL_DEBUG:
  89. sql_str = sql % tuple(args) if args else sql
  90. log.info('sql: \n' + sql_str)
  91. log.info('sql cost: %s' % (time.time() - start))
  92. return df
  93. def getData_json(self,sql):
  94. """
  95. :param sql:
  96. :return: [{},{}]
  97. """
  98. self.cursor.execute(sql)
  99. num_fields = len(self.cursor.description)
  100. field_names = [i[0] for i in self.cursor.description]
  101. df = self.cursor.fetchall()
  102. li=[]
  103. for i in list(df):
  104. li.append(dict(zip(field_names,i)))
  105. return li
  106. def insertData(self, sql, args=None):
  107. # if args:
  108. # log.debug(sql % tuple(args))
  109. # else:
  110. # log.debug(sql)
  111. start = time.time()
  112. self.cursor.execute(sql, args=args)
  113. if MYSQL_DEBUG:
  114. sql_str = sql % tuple(args) if args else sql
  115. log.info('sql: \n' + sql_str)
  116. log.info('sql cost: %s' % (time.time() - start))
  117. self.conn.commit()
  118. def executeWithoutCommit(self, sql, args=None):
  119. return self.cursor.execute(sql, args=args)
  120. def commit(self):
  121. self.conn.commit()
  122. def insertorupdate(self, table, keys, tags, tagvalue, flag, *args):
  123. """
  124. :param table: 表名
  125. :param keys: 联合主键名元组
  126. :param tags: 字段名元组
  127. :param tagvalue: 字段值
  128. :param args: 主键值
  129. :param flag: 控制是否打印日志
  130. :return:
  131. """
  132. # log.info(tags)
  133. sql = "INSERT INTO " + table + " ("
  134. sql += ",".join(keys) + ","
  135. sql += ",".join(tags)
  136. sql += ") SELECT "
  137. sql += "%s," * len(keys)
  138. sql += ("%s," * len(tags))[:-1]
  139. sql += " FROM DUAL WHERE NOT EXISTS (SELECT id FROM " + table
  140. sql += " WHERE "
  141. for _ in keys:
  142. sql += _ + "=%s AND "
  143. sql = sql[:-4]
  144. sql += "LIMIT 1)"
  145. arg = list(args)
  146. arg.extend(tagvalue)
  147. arg.extend(list(args))
  148. rows = self.cursor.execute(sql, args=arg)
  149. if rows == 0:
  150. sql = "UPDATE " + table + " SET "
  151. for _ in tags:
  152. sql += _ + "=%s,"
  153. sql = sql[:-1]
  154. sql += " WHERE "
  155. for _ in keys:
  156. sql += _ + "=%s AND "
  157. sql = sql[:-4]
  158. arg = []
  159. arg.extend(tagvalue)
  160. arg.extend(list(args))
  161. self.cursor.execute(sql, args=arg)
  162. if flag:
  163. log.info(sql % tuple(arg))
  164. self.conn.commit()
  165. def _insertorupdate(self, table, keys, tags, tag_value, flag, key_value, update=False):
  166. if not update:
  167. sql = "INSERT INTO " + table + " ("
  168. sql += ",".join(keys) + ","
  169. sql += ",".join(tags)
  170. sql += ") SELECT "
  171. sql += "%s," * len(keys)
  172. sql += ("%s," * len(tags))[:-1]
  173. sql += " FROM DUAL WHERE NOT EXISTS (SELECT id FROM " + table
  174. sql += " WHERE "
  175. for _ in keys:
  176. sql += _ + "=%s AND "
  177. sql = sql[:-4]
  178. sql += "LIMIT 1)"
  179. arg = list(key_value)
  180. arg.extend(tag_value)
  181. arg.extend(list(key_value))
  182. rows = self.cursor.execute(sql, args=arg)
  183. if rows == 0:
  184. sql = "UPDATE " + table + " SET "
  185. for _ in tags:
  186. sql += _ + "=%s,"
  187. sql = sql[:-1]
  188. sql += " WHERE "
  189. for _ in keys:
  190. sql += _ + "=%s AND "
  191. sql = sql[:-4]
  192. arg = []
  193. arg.extend(tag_value)
  194. arg.extend(list(key_value))
  195. self.cursor.execute(sql, args=arg)
  196. if flag:
  197. log.info(sql % tuple(arg))
  198. else:
  199. sql = "UPDATE " + table + " SET "
  200. for _ in tags:
  201. sql += _ + "=%s,"
  202. sql = sql[:-1]
  203. sql += " WHERE "
  204. for _ in keys:
  205. sql += _ + "=%s AND "
  206. sql = sql[:-4]
  207. arg = []
  208. arg.extend(tag_value)
  209. arg.extend(list(key_value))
  210. self.cursor.execute(sql, args=arg)
  211. if flag:
  212. log.info(sql % tuple(arg))
  213. def _insert_on_duplicate(self, table, keys, tags, tag_value, flag, key_value):
  214. name_all = list(keys)
  215. name_all.extend(tags)
  216. arg = list(key_value)
  217. arg.extend(tag_value)
  218. arg.extend(tag_value)
  219. sql_name = '(' + ','.join(name_all) + ')'
  220. sql_value = '(' + ','.join(['%s'] * len(name_all)) + ')'
  221. sql_update = ','.join([_ + '=%s' for _ in tags])
  222. sql = """
  223. insert into %s
  224. %s
  225. VALUES %s
  226. ON duplicate key UPDATE %s
  227. """ % (table, sql_name, sql_value, sql_update)
  228. self.cursor.execute(sql, args=arg)
  229. if flag:
  230. log.debug(sql % tuple(arg))
  231. def insertorupdatemany(self, table, keys, tags, tag_values, key_values, flag=False, unique_key=False, update=False):
  232. """
  233. :param table: 表名
  234. :param keys: 联合主键名元组
  235. :param tags: 字段名元组
  236. :param tag_values: 字段值组(list or pd.DataFrame)
  237. :param key_values: 主键值组(list or pd.DataFrame)
  238. :param flag: 控制是否打印日志
  239. :param unique_key: keys 是否为table的 unique_key
  240. :return:
  241. ps: 效率(外网): rows / 50; 1000以上更新使用
  242. """
  243. if isinstance(tag_values, pd.DataFrame):
  244. list_tag_value = [list(tag_values.iloc[_, :]) for _ in range(len(tag_values))]
  245. else:
  246. list_tag_value = list(tag_values)
  247. if isinstance(key_values, pd.DataFrame):
  248. list_key_value = [list(key_values.iloc[_, :]) for _ in range(len(key_values))]
  249. else:
  250. list_key_value = list(key_values)
  251. for _ in range(len(list_tag_value)):
  252. tag_value = list_tag_value[_]
  253. key_value = list_key_value[_]
  254. if unique_key:
  255. self._insert_on_duplicate(table, keys, tags, tag_value, flag, key_value)
  256. else:
  257. self._insertorupdate(table, keys, tags, tag_value, flag, key_value, update)
  258. self.conn.commit()
  259. def _check_repeat_key(self, key_list):
  260. tmp = list(map(lambda x: tuple(x), key_list))
  261. if len(tmp) == len(set(tmp)):
  262. return False
  263. else:
  264. last_data = -1
  265. repeat_key = set()
  266. for i in sorted(tmp):
  267. if last_data == i:
  268. repeat_key.add(i)
  269. if len(repeat_key) >= 10:
  270. break
  271. last_data = i
  272. log.error('Reject repeated keys')
  273. log.error('repeat_key: %s' % repeat_key)
  274. return True
  275. def _convert_to_list(self, data):
  276. if isinstance(data, pd.DataFrame):
  277. # np.nan != np.nan 从而判断值为np.nan
  278. list_data = [map(lambda x: None if x != x else x, list(data.iloc[_, :])) for _ in range(len(data))]
  279. li =[]
  280. for i in list_data:
  281. li.append(list(i))
  282. list_data = li
  283. else:
  284. list_data = list(data)
  285. return list_data
  286. def _get_exist_keys_index(self, table, keys, key_values, flag=False):
  287. list_sql_when = []
  288. list_tmp = []
  289. for i in range(len(key_values)):
  290. sql_when = """when (%s)=(%s) then %s""" % (','.join(keys), ','.join(['%s'] * len(key_values[i])), i)
  291. list_sql_when.append(sql_when)
  292. list_tmp.extend(key_values[i])
  293. list_sql_condition = []
  294. for i in range(len(key_values)):
  295. # sql_condition_old = """(%s)=(%s)""" % (','.join(keys), ','.join(['%s'] * len(key_values[i])))
  296. row_condition_list = map(lambda x: '%s = %%s' % x, keys)
  297. sql_condition = """(%s)""" % ' and '.join(row_condition_list)
  298. # print sql_condition_old, sql_condition
  299. list_sql_condition.append(sql_condition)
  300. list_tmp.extend(key_values[i])
  301. sql_where = ' or '.join(list_sql_condition)
  302. sql_case = '\n'.join(list_sql_when)
  303. sql = """
  304. select
  305. case
  306. %s
  307. end
  308. from %s
  309. where %s
  310. """ % (sql_case, table, sql_where)
  311. if flag:
  312. log.info(sql % tuple(list_tmp))
  313. self.cursor.execute(sql, tuple(list_tmp))
  314. print()
  315. result = self.cursor.fetchall()
  316. return map(lambda x: x[0], result)
  317. def insertorupdatemany_v2(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
  318. """
  319. 更新插入多条数据(无key时自动插入, 有keys时更新)
  320. :param table: 表名
  321. :param keys: 联合主键名元组
  322. :param tags: 字段名元组
  323. :param tag_values: 字段值组(list or pd.DataFrame)
  324. :param key_values: 主键值组(list or pd.DataFrame)
  325. :param flag: 控制是否打印日志
  326. :param split: 切割阈值
  327. :return:
  328. ps: 效率(外网): rows^2 / 50000; rows以split为单位分批更新
  329. """
  330. if not isinstance(tag_values, (tuple, list, pd.DataFrame)):
  331. log.error('Type Error')
  332. exit(-1)
  333. return
  334. if len(tag_values) > split:
  335. length = len(tag_values)
  336. for i in range(0, length, split):
  337. start, finish = i, i + split
  338. self.insertorupdatemany_v2(table, keys, tags, tag_values[start:finish], key_values[start:finish], flag, split=split)
  339. return
  340. if len(key_values) == 0 or len(tag_values) == 0:
  341. log.debug('insert or update 0 rows')
  342. return
  343. tag_values = self._convert_to_list(tag_values)
  344. key_values = self._convert_to_list(key_values)
  345. assert self._check_repeat_key(key_values) == False
  346. exist_key_index = list(self._get_exist_keys_index(table, keys, key_values, flag))
  347. new_key_index = list(set(range(len(key_values))) - set(exist_key_index))
  348. update_keys = list(map(lambda x: key_values[x], exist_key_index))
  349. update_tags = list(map(lambda x: tag_values[x], exist_key_index))
  350. insert_keys = list(map(lambda x: key_values[x], new_key_index))
  351. insert_tags = list(map(lambda x: tag_values[x], new_key_index))
  352. self.insert_many(table=table,
  353. keys=keys,
  354. tags=tags,
  355. tag_values=insert_tags,
  356. key_values=insert_keys,
  357. flag=flag)
  358. self.update_many(table=table,
  359. keys=keys,
  360. tags=tags,
  361. tag_values=update_tags,
  362. key_values=update_keys,
  363. flag=flag,
  364. split=split)
  365. def insertorupdatemany_v3(self, df, table, keys, tags, flag=False, split=80):
  366. self.insertorupdatemany_v2(
  367. table=table,
  368. keys=keys,
  369. tags=tags,
  370. tag_values=df[tags],
  371. key_values=df[keys],
  372. flag=flag,
  373. split=split
  374. )
  375. def _get_s_format(self, data):
  376. """
  377. Args:
  378. data: [[featureA1, featureB1, ...], [featureA2, featureB2, ...], ...]
  379. Returns:
  380. format of %s and real value
  381. Example:
  382. [['2017-07-01', 78], ['2017-07-01', 1]] ->
  383. ('((%s, %s), (%s, %s))', ['2017-07-01', 78, '2017-07-01', 1])
  384. """
  385. list_tmp_s = []
  386. values = []
  387. for _ in data:
  388. tmp_s = ','.join(len(_) * ['%s'])
  389. values.extend(_)
  390. if len(_) > 1:
  391. tmp_s = '(' + tmp_s + ')'
  392. list_tmp_s.append(tmp_s)
  393. format_s = '(' + ','.join(list_tmp_s) + ')'
  394. return format_s, values
  395. def delete_by_key(self, table, keys, key_values, flag=False):
  396. """
  397. Args:
  398. table: 表名
  399. keys: 联合主键名元组
  400. key_values: 主键值组(list or pd.DataFrame)
  401. flag: 控制是否打印日志
  402. Examples:
  403. delete_by_key('table_test', keys=['date'], key_values=[['2017-07-01'], ['2017-07-02']], flag=False)
  404. delete_by_key('table_test', keys=['date'], key_values=['2017-07-01'], flag=False)
  405. """
  406. if len(key_values) == 0:
  407. return
  408. if not (isinstance(key_values[0], (list, tuple)) or isinstance(key_values, pd.DataFrame)):
  409. key_values_list = [key_values]
  410. else:
  411. key_values_list = self._convert_to_list(key_values)
  412. sql_keys = '(' + ','.join(keys) + ')'
  413. contact_s, values_s = self._get_s_format(key_values_list)
  414. sql_del = """
  415. delete from %s
  416. where %s in %s
  417. """ % (table, sql_keys, contact_s)
  418. if flag:
  419. log.debug(sql_del % tuple(values_s))
  420. self.cursor.execute(sql_del, tuple(values_s))
  421. self.conn.commit()
  422. def insert_many(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
  423. """
  424. 直接插入多条数据
  425. :param table: 表名
  426. :param keys: 联合主键名元组
  427. :param tags: 字段名元组
  428. :param tag_values: 字段值组(list or pd.DataFrame)
  429. :param key_values: 主键值组(list or pd.DataFrame)
  430. :param flag: 控制是否打印日志
  431. :return:
  432. Examples: 参照 insertorupdatemany_v2
  433. insert into table
  434. (count_date, cid, tag1, tag2)
  435. values ('2017-01-01', 10, 1, 'a'), ('2017-01-02', 20, 2, 'b'), ...
  436. """
  437. if len(key_values) == 0 or len(tag_values) == 0:
  438. log.debug('insert 0 rows')
  439. return
  440. if len(tag_values) > split:
  441. length = len(tag_values)
  442. for i in range(0, length, split):
  443. start, finish = i, i + split
  444. self.insert_many(table, keys, tags, tag_values[start:finish], key_values[start:finish], flag, split=split)
  445. return
  446. tag_values = self._convert_to_list(tag_values)
  447. key_values = self._convert_to_list(key_values)
  448. feature_total = "(" + ",".join(keys + tags) + ")"
  449. tmp_s = "(" + ",".join(["%s"] * len(keys + tags)) + ")"
  450. tmp_s_concat = ",\n".join([tmp_s] * len(key_values))
  451. sql_insert = """
  452. Insert into %s
  453. %s
  454. values %s""" % (table, feature_total, tmp_s_concat)
  455. value_insert = []
  456. for _ in zip(key_values, tag_values):
  457. value_insert.extend(_[0] + _[1])
  458. if flag:
  459. log.debug(sql_insert % tuple(value_insert))
  460. t0 = time.time()
  461. self.cursor.execute(sql_insert,tuple(value_insert))
  462. log.debug('insert %s rows, cost: %s' % (len(key_values), time.time() - t0))
  463. self.conn.commit()
  464. def update_many(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
  465. """
  466. 更新多条数据(无key时不会自动插入)
  467. :param table: 表名
  468. :param keys: 联合主键名元组
  469. :param tags: 字段名元组
  470. :param tag_values: 字段值组(list or pd.DataFrame)
  471. :param key_values: 主键值组(list or pd.DataFrame)
  472. :param flag: 控制是否打印日志
  473. :param split: 分批更新量
  474. :return:
  475. Examples: 参照 insertorupdatemany_v2
  476. # 单条 update sql tag1=1, tag2='a' 插入到 (count_date, cid) =('2017-01-01', 10)
  477. update table
  478. set tag1=1, tag2='a'
  479. where (count_date, cid) =('2017-01-01', 10)
  480. # 多条组合 update sql
  481. # tag1=1, tag2='a' 插入到 (count_date, cid) =('2017-01-01', 10);
  482. # tag1=1, tag2='a' 插入到 (count_date, cid) =('2017-01-01', 10);
  483. update table
  484. set tag1 = case
  485. when (count_date, cid)=('2017-01-01', 10) then 1
  486. when (count_date, cid)=('2017-01-02', 20) then 2
  487. ...
  488. ,
  489. tag_2 = case
  490. when (count_date, cid)=('2017-01-01', 10) then 'a'
  491. when (count_date, cid)=('2017-01-02', 20) then 'b'
  492. ...
  493. where (count_date, cid)=('2017-01-01', 10) or (count_date, cid)=('2017-01-02', 20) or ...
  494. """
  495. if len(tag_values) > split:
  496. length = len(tag_values)
  497. for i in range(0, length, split):
  498. start, finish = i, i + split
  499. self.update_many(table, keys, tags, tag_values[start:finish], key_values[start:finish], flag, split=split)
  500. return
  501. if len(key_values) == 0 or len(tag_values) == 0:
  502. log.debug('update 0 rows')
  503. return
  504. tag_values = self._convert_to_list(tag_values)
  505. key_values = self._convert_to_list(key_values)
  506. if self._check_repeat_key(key_values):
  507. return
  508. update_value = []
  509. sql_keys = ','.join(keys)
  510. if len(keys) > 1:
  511. sql_keys = '(' + sql_keys + ')'
  512. sql_key_values = ','.join(['%s'] * len(keys))
  513. if len(keys) > 1:
  514. sql_key_values = '(' + sql_key_values + ')'
  515. sql_set_list = []
  516. for i in range(len(tags)):
  517. sql_when_list = []
  518. for j in range(len(tag_values)):
  519. sql_when = """when %s=%s then %s """ % (sql_keys, sql_key_values, '%s')
  520. update_value.extend(key_values[j])
  521. update_value.append(tag_values[j][i])
  522. sql_when_list.append(sql_when)
  523. sql_when_concat = '\n\t'.join(sql_when_list)
  524. sql_set = """%s = case \n\t %s\n end""" % (tags[i], sql_when_concat)
  525. sql_set_list.append(sql_set)
  526. for _ in key_values:
  527. update_value.extend(_)
  528. sql_set_concat = ',\n'.join(sql_set_list)
  529. list_sql_condition = []
  530. for i in range(len(key_values)):
  531. row_condition_list = map(lambda x: '%s = %%s' % x, keys)
  532. sql_condition = """(%s)""" % ' and '.join(row_condition_list)
  533. list_sql_condition.append(sql_condition)
  534. sql_where = ' or '.join(list_sql_condition)
  535. # condition = ' or\n\t'.join([sql_keys + '=' + sql_key_values] * len(tag_values))
  536. # print condition
  537. sql = """update %s\n set %s\n where %s""" % (table, sql_set_concat, sql_where)
  538. if flag:
  539. log.info(sql % tuple(update_value))
  540. t0 = time.time()
  541. self.cursor.execute(sql, tuple(update_value))
  542. self.conn.commit()
  543. log.debug('update %s rows, cost: %s' % (len(key_values), time.time() - t0))
  544. def getColumn(self,table,flag=0):
  545. "获取表的所有列"
  546. sql="SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` " \
  547. "WHERE `TABLE_NAME`='{}' ORDER BY ordinal_position".format(table)
  548. self.cursor.execute(sql)
  549. a= self.cursor.fetchall()
  550. str=''
  551. li=[]
  552. for i in a:
  553. str+=i[0]+','
  554. li.append(i[0])
  555. if flag:
  556. return li
  557. else:
  558. return str[:-1]