# DataBaseOperation.py

  1. """
  2. @desc 数据库操作方法封装
  3. @auth chenkai
  4. @date 2020/11/19
  5. @py_version py3.7
  6. """
  7. import pymysql
  8. # from clickhouse_sqlalchemy import make_session
  9. # from sqlalchemy import create_engine
  10. import logging as log
  11. import pandas as pd
  12. import time
  13. from model.common.log import logger
  14. log = logger()
  15. pd.set_option('display.max_columns', None)
  16. pd.set_option('display.width', 1000)
  17. MYSQL_DEBUG = 1


class MysqlOperation:
    def __init__(self, host, user, passwd, db, port=3306):
        try:
            self.conn = pymysql.connect(host=host,
                                        user=user,
                                        passwd=passwd,
                                        db=db,
                                        charset='utf8mb4',
                                        port=port)
            self.cursor = self.conn.cursor()
        except Exception as e:
            log.info(e)

    def set_dict_cursor(self):
        """
        Fetch rows as dictionaries instead of tuples.
        """
        self.cursor = self.conn.cursor(pymysql.cursors.DictCursor)

    def getData(self, sql, args=None):
        start = time.time()
        # if args:
        #     log.debug(sql % tuple(args))
        # else:
        #     log.debug(sql)
        self.cursor.execute(sql, args=args)
        result = self.cursor.fetchall()
        if MYSQL_DEBUG:
            sql_str = sql % tuple(args) if args else sql
            log.info('sql: \n' + sql_str)
            log.info('sql cost: %s' % (time.time() - start))
        return result

    def execute(self, sql):
        start = time.time()
        self.cursor.execute(sql)
        self.conn.commit()
        if MYSQL_DEBUG:
            log.info('sql: \n' + sql)
            log.info('sql cost: %s' % (time.time() - start))

    def getOne(self, sql, args=None):
        result = self.getData(sql, args)
        return result[0][0]

    def getData_pd(self, sql, args=None):
        start = time.time()
        # if args:
        #     log.debug(sql % tuple(args))
        # else:
        #     log.debug(sql)
        self.cursor.execute(sql, args=args)
        num_fields = len(self.cursor.description)
        field_names = [i[0] for i in self.cursor.description]
        df = self.cursor.fetchall()
        df = pd.DataFrame(data=list(df), columns=field_names)
        if MYSQL_DEBUG:
            sql_str = sql % tuple(args) if args else sql
            log.info('sql: \n' + sql_str)
            log.info('sql cost: %s' % (time.time() - start))
        return df

    def insertData(self, sql, args=None):
        # if args:
        #     log.debug(sql % tuple(args))
        # else:
        #     log.debug(sql)
        start = time.time()
        self.cursor.execute(sql, args=args)
        if MYSQL_DEBUG:
            sql_str = sql % tuple(args) if args else sql
            log.info('sql: \n' + sql_str)
            log.info('sql cost: %s' % (time.time() - start))
        self.conn.commit()

    def executeWithoutCommit(self, sql, args=None):
        return self.cursor.execute(sql, args=args)

    def commit(self):
        self.conn.commit()
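
    # Usage sketch for the basic helpers above (hypothetical credentials and a
    # made-up table `t_demo`; adjust to your own host and schema):
    #
    #   db = MysqlOperation(host='127.0.0.1', user='root', passwd='***', db='test')
    #   rows = db.getData("SELECT id, name FROM t_demo WHERE id > %s", args=[100])
    #   one = db.getOne("SELECT COUNT(*) FROM t_demo")
    #   df = db.getData_pd("SELECT * FROM t_demo LIMIT %s", args=[10])
    #   db.insertData("INSERT INTO t_demo (name) VALUES (%s)", args=['foo'])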

    def insertorupdate(self, table, keys, tags, tagvalue, flag, *args):
        """
        :param table: table name
        :param keys: tuple of composite-key column names
        :param tags: tuple of non-key column names
        :param tagvalue: values for the non-key columns
        :param args: values for the key columns
        :param flag: whether to log the generated SQL
        :return:
        """
        # log.info(tags)
        sql = "INSERT INTO " + table + " ("
        sql += ",".join(keys) + ","
        sql += ",".join(tags)
        sql += ") SELECT "
        sql += "%s," * len(keys)
        sql += ("%s," * len(tags))[:-1]
        sql += " FROM DUAL WHERE NOT EXISTS (SELECT id FROM " + table
        sql += " WHERE "
        for _ in keys:
            sql += _ + "=%s AND "
        sql = sql[:-4]
        sql += "LIMIT 1)"
        arg = list(args)
        arg.extend(tagvalue)
        arg.extend(list(args))
        rows = self.cursor.execute(sql, args=arg)
        if rows == 0:
            sql = "UPDATE " + table + " SET "
            for _ in tags:
                sql += _ + "=%s,"
            sql = sql[:-1]
            sql += " WHERE "
            for _ in keys:
                sql += _ + "=%s AND "
            sql = sql[:-4]
            arg = []
            arg.extend(tagvalue)
            arg.extend(list(args))
            self.cursor.execute(sql, args=arg)
        if flag:
            log.info(sql % tuple(arg))
        self.conn.commit()
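
    # Usage sketch (hypothetical table `t_stat` keyed on (count_date, cid); the
    # trailing *args carry the key values in the same order as `keys`):
    #
    #   db.insertorupdate('t_stat', ('count_date', 'cid'), ('pv', 'uv'),
    #                     [100, 80], True, '2020-11-19', 42)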

    def _insertorupdate(self, table, keys, tags, tag_value, flag, key_value, update=False):
        if not update:
            sql = "INSERT INTO " + table + " ("
            sql += ",".join(keys) + ","
            sql += ",".join(tags)
            sql += ") SELECT "
            sql += "%s," * len(keys)
            sql += ("%s," * len(tags))[:-1]
            sql += " FROM DUAL WHERE NOT EXISTS (SELECT id FROM " + table
            sql += " WHERE "
            for _ in keys:
                sql += _ + "=%s AND "
            sql = sql[:-4]
            sql += "LIMIT 1)"
            arg = list(key_value)
            arg.extend(tag_value)
            arg.extend(list(key_value))
            rows = self.cursor.execute(sql, args=arg)
            if rows == 0:
                sql = "UPDATE " + table + " SET "
                for _ in tags:
                    sql += _ + "=%s,"
                sql = sql[:-1]
                sql += " WHERE "
                for _ in keys:
                    sql += _ + "=%s AND "
                sql = sql[:-4]
                arg = []
                arg.extend(tag_value)
                arg.extend(list(key_value))
                self.cursor.execute(sql, args=arg)
            if flag:
                log.info(sql % tuple(arg))
        else:
            sql = "UPDATE " + table + " SET "
            for _ in tags:
                sql += _ + "=%s,"
            sql = sql[:-1]
            sql += " WHERE "
            for _ in keys:
                sql += _ + "=%s AND "
            sql = sql[:-4]
            arg = []
            arg.extend(tag_value)
            arg.extend(list(key_value))
            self.cursor.execute(sql, args=arg)
            if flag:
                log.info(sql % tuple(arg))

    def _insert_on_duplicate(self, table, keys, tags, tag_value, flag, key_value):
        name_all = list(keys)
        name_all.extend(tags)
        arg = list(key_value)
        arg.extend(tag_value)
        arg.extend(tag_value)
        sql_name = '(' + ','.join(name_all) + ')'
        sql_value = '(' + ','.join(['%s'] * len(name_all)) + ')'
        sql_update = ','.join([_ + '=%s' for _ in tags])
        sql = """
        insert into %s
        %s
        VALUES %s
        ON duplicate key UPDATE %s
        """ % (table, sql_name, sql_value, sql_update)
        self.cursor.execute(sql, args=arg)
        if flag:
            log.debug(sql % tuple(arg))

    def insertorupdatemany(self, table, keys, tags, tag_values, key_values, flag=False, unique_key=False, update=False):
        """
        :param table: table name
        :param keys: tuple of composite-key column names
        :param tags: tuple of non-key column names
        :param tag_values: non-key values, one row per record (list or pd.DataFrame)
        :param key_values: key values, one row per record (list or pd.DataFrame)
        :param flag: whether to log the generated SQL
        :param unique_key: whether `keys` is a UNIQUE key of the table
        :return:
        Note: efficiency over the public network is roughly rows / 50; use this for updating 1000+ rows.
        """
        if isinstance(tag_values, pd.DataFrame):
            list_tag_value = [list(tag_values.iloc[_, :]) for _ in range(len(tag_values))]
        else:
            list_tag_value = list(tag_values)
        if isinstance(key_values, pd.DataFrame):
            list_key_value = [list(key_values.iloc[_, :]) for _ in range(len(key_values))]
        else:
            list_key_value = list(key_values)
        for _ in range(len(list_tag_value)):
            tag_value = list_tag_value[_]
            key_value = list_key_value[_]
            if unique_key:
                self._insert_on_duplicate(table, keys, tags, tag_value, flag, key_value)
            else:
                self._insertorupdate(table, keys, tags, tag_value, flag, key_value, update)
        self.conn.commit()
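
    # Usage sketch (hypothetical DataFrame `df` with the key and tag columns;
    # pass unique_key=True when (count_date, cid) is a UNIQUE key so the faster
    # INSERT ... ON DUPLICATE KEY UPDATE path is taken):
    #
    #   db.insertorupdatemany('t_stat', ('count_date', 'cid'), ('pv', 'uv'),
    #                         tag_values=df[['pv', 'uv']],
    #                         key_values=df[['count_date', 'cid']],
    #                         unique_key=True)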

    def _check_repeat_key(self, key_list):
        tmp = list(map(lambda x: tuple(x), key_list))
        if len(tmp) == len(set(tmp)):
            return False
        else:
            last_data = -1
            repeat_key = set()
            for i in sorted(tmp):
                if last_data == i:
                    repeat_key.add(i)
                    if len(repeat_key) >= 10:
                        break
                last_data = i
            log.error('Reject repeated keys')
            log.error('repeat_key: %s' % repeat_key)
            return True

    def _convert_to_list(self, data):
        if isinstance(data, pd.DataFrame):
            # np.nan != np.nan, so this test replaces NaN values with None
            list_data = [map(lambda x: None if x != x else x, list(data.iloc[_, :])) for _ in range(len(data))]
            li = []
            for i in list_data:
                li.append(list(i))
            list_data = li
        else:
            list_data = list(data)
        return list_data

    def _get_exist_keys_index(self, table, keys, key_values, flag=False):
        list_sql_when = []
        list_tmp = []
        for i in range(len(key_values)):
            sql_when = """when (%s)=(%s) then %s""" % (','.join(keys), ','.join(['%s'] * len(key_values[i])), i)
            list_sql_when.append(sql_when)
            list_tmp.extend(key_values[i])
        list_sql_condition = []
        for i in range(len(key_values)):
            # sql_condition_old = """(%s)=(%s)""" % (','.join(keys), ','.join(['%s'] * len(key_values[i])))
            row_condition_list = map(lambda x: '%s = %%s' % x, keys)
            sql_condition = """(%s)""" % ' and '.join(row_condition_list)
            list_sql_condition.append(sql_condition)
            list_tmp.extend(key_values[i])
        sql_where = ' or '.join(list_sql_condition)
        sql_case = '\n'.join(list_sql_when)
        sql = """
        select
        case
        %s
        end
        from %s
        where %s
        """ % (sql_case, table, sql_where)
        if flag:
            log.info(sql % tuple(list_tmp))
        self.cursor.execute(sql, tuple(list_tmp))
        result = self.cursor.fetchall()
        return map(lambda x: x[0], result)

    def insertorupdatemany_v2(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
        """
        Insert or update multiple rows (rows whose keys do not exist are inserted, existing keys are updated).
        :param table: table name
        :param keys: tuple of composite-key column names
        :param tags: tuple of non-key column names
        :param tag_values: non-key values, one row per record (list or pd.DataFrame)
        :param key_values: key values, one row per record (list or pd.DataFrame)
        :param flag: whether to log the generated SQL
        :param split: batch size threshold
        :return:
        Note: efficiency over the public network is roughly rows^2 / 50000; rows are processed in batches of `split`.
        """
        if not isinstance(tag_values, (tuple, list, pd.DataFrame)):
            log.error('Type Error')
            exit(-1)
            return
        if len(tag_values) > split:
            length = len(tag_values)
            for i in range(0, length, split):
                start, finish = i, i + split
                self.insertorupdatemany_v2(table, keys, tags, tag_values[start:finish], key_values[start:finish], flag, split=split)
            return
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('insert or update 0 rows')
            return
        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)
        assert self._check_repeat_key(key_values) == False
        exist_key_index = list(self._get_exist_keys_index(table, keys, key_values, flag))
        new_key_index = list(set(range(len(key_values))) - set(exist_key_index))
        update_keys = list(map(lambda x: key_values[x], exist_key_index))
        update_tags = list(map(lambda x: tag_values[x], exist_key_index))
        insert_keys = list(map(lambda x: key_values[x], new_key_index))
        insert_tags = list(map(lambda x: tag_values[x], new_key_index))
        self.insert_many(table=table,
                         keys=keys,
                         tags=tags,
                         tag_values=insert_tags,
                         key_values=insert_keys,
                         flag=flag)
        self.update_many(table=table,
                         keys=keys,
                         tags=tags,
                         tag_values=update_tags,
                         key_values=update_keys,
                         flag=flag,
                         split=split)

    def insertorupdatemany_v3(self, df, table, keys, tags, flag=False, split=80):
        self.insertorupdatemany_v2(
            table=table,
            keys=keys,
            tags=tags,
            tag_values=df[tags],
            key_values=df[keys],
            flag=flag,
            split=split
        )
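
    # Usage sketch (`df` is a hypothetical DataFrame carrying both the key and
    # tag columns; pass keys/tags as lists so the df[...] column selection works):
    #
    #   db.insertorupdatemany_v3(df, 't_stat',
    #                            keys=['count_date', 'cid'],
    #                            tags=['pv', 'uv'])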

    def _get_s_format(self, data):
        """
        Args:
            data: [[featureA1, featureB1, ...], [featureA2, featureB2, ...], ...]
        Returns:
            format of %s and real value
        Example:
            [['2017-07-01', 78], ['2017-07-01', 1]] ->
            ('((%s, %s), (%s, %s))', ['2017-07-01', 78, '2017-07-01', 1])
        """
        list_tmp_s = []
        values = []
        for _ in data:
            tmp_s = ','.join(len(_) * ['%s'])
            values.extend(_)
            if len(_) > 1:
                tmp_s = '(' + tmp_s + ')'
            list_tmp_s.append(tmp_s)
        format_s = '(' + ','.join(list_tmp_s) + ')'
        return format_s, values

    def delete_by_key(self, table, keys, key_values, flag=False):
        """
        Args:
            table: table name
            keys: tuple of composite-key column names
            key_values: key values, one row per record (list or pd.DataFrame)
            flag: whether to log the generated SQL
        Examples:
            delete_by_key('table_test', keys=['date'], key_values=[['2017-07-01'], ['2017-07-02']], flag=False)
            delete_by_key('table_test', keys=['date'], key_values=['2017-07-01'], flag=False)
        """
        if len(key_values) == 0:
            return
        if not (isinstance(key_values[0], (list, tuple)) or isinstance(key_values, pd.DataFrame)):
            key_values_list = [key_values]
        else:
            key_values_list = self._convert_to_list(key_values)
        sql_keys = '(' + ','.join(keys) + ')'
        contact_s, values_s = self._get_s_format(key_values_list)
        sql_del = """
        delete from %s
        where %s in %s
        """ % (table, sql_keys, contact_s)
        if flag:
            log.debug(sql_del % tuple(values_s))
        self.cursor.execute(sql_del, tuple(values_s))
        self.conn.commit()

    def insert_many(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
        """
        Insert multiple rows directly.
        :param table: table name
        :param keys: tuple of composite-key column names
        :param tags: tuple of non-key column names
        :param tag_values: non-key values, one row per record (list or pd.DataFrame)
        :param key_values: key values, one row per record (list or pd.DataFrame)
        :param flag: whether to log the generated SQL
        :return:
        Examples: see insertorupdatemany_v2; the generated SQL looks like
            insert into table
            (count_date, cid, tag1, tag2)
            values ('2017-01-01', 10, 1, 'a'), ('2017-01-02', 20, 2, 'b'), ...
        """
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('insert 0 rows')
            return
        if len(tag_values) > split:
            length = len(tag_values)
            for i in range(0, length, split):
                start, finish = i, i + split
                self.insert_many(table, keys, tags, tag_values[start:finish], key_values[start:finish], flag, split=split)
            return
        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)
        feature_total = "(" + ",".join(keys + tags) + ")"
        tmp_s = "(" + ",".join(["%s"] * len(keys + tags)) + ")"
        tmp_s_concat = ",\n".join([tmp_s] * len(key_values))
        sql_insert = """
        Insert into %s
        %s
        values %s""" % (table, feature_total, tmp_s_concat)
        value_insert = []
        for _ in zip(key_values, tag_values):
            value_insert.extend(_[0] + _[1])
        if flag:
            log.debug(sql_insert % tuple(value_insert))
        t0 = time.time()
        self.cursor.execute(sql_insert, tuple(value_insert))
        log.debug('insert %s rows, cost: %s' % (len(key_values), time.time() - t0))
        self.conn.commit()
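
    # Usage sketch (keys and tags must be the same sequence type here because the
    # method concatenates them with `keys + tags`; values below are placeholders):
    #
    #   db.insert_many('t_stat', keys=('count_date', 'cid'), tags=('pv', 'uv'),
    #                  tag_values=[[100, 80], [120, 90]],
    #                  key_values=[['2020-11-19', 1], ['2020-11-19', 2]])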

    def update_many(self, table, keys, tags, tag_values, key_values, flag=False, split=80):
        """
        Update multiple rows (rows whose keys do not exist are NOT inserted).
        :param table: table name
        :param keys: tuple of composite-key column names
        :param tags: tuple of non-key column names
        :param tag_values: non-key values, one row per record (list or pd.DataFrame)
        :param key_values: key values, one row per record (list or pd.DataFrame)
        :param flag: whether to log the generated SQL
        :param split: batch size per UPDATE statement
        :return:
        Examples: see insertorupdatemany_v2
            # single-row update: set tag1=1, tag2='a' where (count_date, cid)=('2017-01-01', 10)
            update table
            set tag1=1, tag2='a'
            where (count_date, cid)=('2017-01-01', 10)
            # combined multi-row update:
            # tag1=1, tag2='a' applied to (count_date, cid)=('2017-01-01', 10);
            # tag1=2, tag2='b' applied to (count_date, cid)=('2017-01-02', 20);
            update table
            set tag1 = case
                when (count_date, cid)=('2017-01-01', 10) then 1
                when (count_date, cid)=('2017-01-02', 20) then 2
                ...
            ,
            tag2 = case
                when (count_date, cid)=('2017-01-01', 10) then 'a'
                when (count_date, cid)=('2017-01-02', 20) then 'b'
                ...
            where (count_date, cid)=('2017-01-01', 10) or (count_date, cid)=('2017-01-02', 20) or ...
        """
        if len(tag_values) > split:
            length = len(tag_values)
            for i in range(0, length, split):
                start, finish = i, i + split
                self.update_many(table, keys, tags, tag_values[start:finish], key_values[start:finish], flag, split=split)
            return
        if len(key_values) == 0 or len(tag_values) == 0:
            log.debug('update 0 rows')
            return
        tag_values = self._convert_to_list(tag_values)
        key_values = self._convert_to_list(key_values)
        if self._check_repeat_key(key_values):
            return
        update_value = []
        sql_keys = ','.join(keys)
        if len(keys) > 1:
            sql_keys = '(' + sql_keys + ')'
        sql_key_values = ','.join(['%s'] * len(keys))
        if len(keys) > 1:
            sql_key_values = '(' + sql_key_values + ')'
        sql_set_list = []
        for i in range(len(tags)):
            sql_when_list = []
            for j in range(len(tag_values)):
                sql_when = """when %s=%s then %s """ % (sql_keys, sql_key_values, '%s')
                update_value.extend(key_values[j])
                update_value.append(tag_values[j][i])
                sql_when_list.append(sql_when)
            sql_when_concat = '\n\t'.join(sql_when_list)
            sql_set = """%s = case \n\t %s\n end""" % (tags[i], sql_when_concat)
            sql_set_list.append(sql_set)
        for _ in key_values:
            update_value.extend(_)
        sql_set_concat = ',\n'.join(sql_set_list)
        list_sql_condition = []
        for i in range(len(key_values)):
            row_condition_list = map(lambda x: '%s = %%s' % x, keys)
            sql_condition = """(%s)""" % ' and '.join(row_condition_list)
            list_sql_condition.append(sql_condition)
        sql_where = ' or '.join(list_sql_condition)
        # condition = ' or\n\t'.join([sql_keys + '=' + sql_key_values] * len(tag_values))
        sql = """update %s\n set %s\n where %s""" % (table, sql_set_concat, sql_where)
        if flag:
            log.info(sql % tuple(update_value))
        t0 = time.time()
        self.cursor.execute(sql, tuple(update_value))
        self.conn.commit()
        log.debug('update %s rows, cost: %s' % (len(key_values), time.time() - t0))


# class CkOperation:
#     cursor = None
#     session = None
#
#     def __init__(self, conf):
#         try:
#             connection = 'clickhouse://{user}:{passwd}@{host}:{port}/{db}'.format(**conf)
#             engine = create_engine(connection, pool_size=100, pool_recycle=3600, pool_timeout=20)
#             self.session = make_session(engine)
#
#         except Exception as e:
#             log.info(e)
#
#     def execute(self, sql):
#         self.cursor = self.session.execute(sql)
#         try:
#             fields = self.cursor._metadata.keys
#             return [dict(zip(fields, item)) for item in self.cursor.fetchall()]
#         except Exception as e:
#             log.info(e)
#
#     def getData_pd(self, sql):
#         li = self.execute(sql)
#         return pd.DataFrame(li)
#
#     def getOne(self, sql):
#         li = self.execute(sql)
#         return [i for i in li[0].values()][0]
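

# Minimal end-to-end sketch, assuming a reachable MySQL instance and a table
# t_stat(count_date, cid, pv, uv); the host, credentials, and schema below are
# placeholders, not part of the original module.
if __name__ == '__main__':
    db = MysqlOperation(host='127.0.0.1', user='root', passwd='***', db='test')
    demo = pd.DataFrame({'count_date': ['2020-11-19', '2020-11-19'],
                         'cid': [1, 2],
                         'pv': [100, 120],
                         'uv': [80, 90]})
    # Upsert the frame, then read it back as a DataFrame.
    db.insertorupdatemany_v3(demo, 't_stat',
                             keys=['count_date', 'cid'],
                             tags=['pv', 'uv'])
    print(db.getData_pd("SELECT * FROM t_stat WHERE count_date = %s", args=['2020-11-19']))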