import sqlparse
import psycopg2
import re
def format_sql(sql_content):
    """Return *sql_content* re-indented and with SQL comments removed.

    Both the argument and the return value are strings; the actual
    formatting is delegated to ``sqlparse.format``.
    """
    return sqlparse.format(
        sql_content,
        strip_comments=True,
        reindent=True,
    )

# Identifier following FROM or JOIN. The name stops at whitespace,
# parentheses, commas and the statement-terminating semicolon, so
# "from foo;" yields "foo" and "from (select ..." yields nothing.
_TABLE_NAME_RE = re.compile(
    r'\bfrom\s+([^\s(),;]+)|\bjoin\s+([^\s(),;]+)',
    re.IGNORECASE,
)

# substring(x FROM 1 FOR 2) / extract(year FROM ts) contain a FROM keyword
# that does not introduce a table, so these calls are blanked out first.
# The match is lazy and ends at the first ")".
_FROM_INSIDE_CALL_RE = re.compile(r'(substring|extract)\s*\(.*?\)', re.DOTALL)

# Cheap gate: only statements containing the WITH keyword are scanned
# for common-table-expression names.
_WITH_KEYWORD_RE = re.compile(r'\bwith\b')

# CTE name in "with [recursive] name [(cols)] as [[not] materialized] ("
# and in every following ", name [(cols)] as (" of the same WITH list.
_CTE_NAME_RE = re.compile(
    r'(?:\bwith\s+(?:recursive\s+)?|,\s*)'
    r'(\w+)\s*'
    r'(?:\([^()]*\)\s*)?'
    r'\bas\s*(?:(?:not\s+)?materialized\s*)?\('
)


def extract_table_names(sql_query):
    """Extract the table names referenced after FROM / JOIN in *sql_query*.

    The text is split into statements with ``sqlparse.parse``; each
    statement is lower-cased and scanned with regular expressions. Names
    defined by a WITH clause (common table expressions) are removed from
    the result of the statement that defines them, because they are not
    real tables.

    Args:
        sql_query: SQL text, possibly containing several statements.

    Returns:
        A sorted list of distinct, lower-cased names exactly as written
        after FROM / JOIN (a schema prefix such as ``public.`` and any
        quoting are kept). An empty list is returned for empty input.
    """
    if not sql_query:
        return []

    table_names = set()
    for statement in sqlparse.parse(sql_query):
        statement_str = str(statement).lower()
        # Drop substring(...) / extract(...) so their inner FROM keyword
        # is not mistaken for a table reference.
        statement_str = _FROM_INSIDE_CALL_RE.sub('', statement_str)

        found = set()
        for groups in _TABLE_NAME_RE.findall(statement_str):
            # Each match is a (from_name, join_name) pair; exactly one
            # of the two groups is non-empty.
            found.update(name for name in groups if name)

        # A CTE is only visible inside its own statement, so subtract the
        # names per statement instead of globally.
        if _WITH_KEYWORD_RE.search(statement_str):
            found.difference_update(_CTE_NAME_RE.findall(statement_str))

        table_names |= found

    # Sorted so that the output is deterministic between runs.
    return sorted(table_names)


# NOTE(review): the connection settings (including the password) are
# hard-coded; consider loading them from the environment or a config file.
mydb = psycopg2.connect(
    host="x.x.x.x",
    user="postgres",
    password="postgres",
    database="ppp",
)

# Always release the connection, even if a query or the parsing below raises.
try:
    # The cursor context manager closes the cursor when the block exits.
    with mydb.cursor() as mycursor:
        sql = "SELECT  method_name , \"sql\"  FROM public.api_sql"
        mycursor.execute(sql)
        rows = mycursor.fetchall()

    for method_name, api_sql in rows:
        print(method_name)
        # Normalize the stored SQL, then list the tables it references.
        parse_str = format_sql(api_sql)
        table_names = extract_table_names(parse_str)
        print(table_names)
finally:
    mydb.close()