Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save peterbean410/6e555086f9d38f11e27d7cfa99907038 to your computer and use it in GitHub Desktop.
Save peterbean410/6e555086f9d38f11e27d7cfa99907038 to your computer and use it in GitHub Desktop.
MySQL-Python: Copy Sample Data From Source To Target

MySQL-Python: Copy Sample Data From Source To Target

# Rename this file to .env
# Source database
SOURCE_DB_HOST=localhost
SOURCE_DB_USER=root
SOURCE_DB_PASSWORD=
SOURCE_DB_NAME=source_db
# Target database
TARGET_DB_HOST=localhost
TARGET_DB_USER=root
TARGET_DB_PASSWORD=
TARGET_DB_NAME=target_db
import os
import re
from collections import defaultdict

import mysql.connector
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()


def _connect(prefix):
    """Open a MySQL connection configured by ``<prefix>_DB_*`` environment variables.

    Reads ``<prefix>_DB_HOST``, ``<prefix>_DB_USER``, ``<prefix>_DB_PASSWORD`` and
    ``<prefix>_DB_NAME``. An unset password is passed as an empty string (the
    sample .env leaves it blank), the same way for source and target.
    """
    return mysql.connector.connect(
        host=os.getenv(f"{prefix}_DB_HOST"),
        user=os.getenv(f"{prefix}_DB_USER"),
        password=os.getenv(f"{prefix}_DB_PASSWORD", ""),
        database=os.getenv(f"{prefix}_DB_NAME"),
    )


source_connection = _connect("SOURCE")
source_cursor = source_connection.cursor()
target_connection = _connect("TARGET")
target_cursor = target_connection.cursor()
# Base tables of the source schema, in a stable (alphabetical) order.
source_cursor.execute(
    "SELECT TABLE_NAME"
    " FROM information_schema.TABLES"
    " WHERE TABLE_SCHEMA = DATABASE()"
    " AND TABLE_TYPE = 'BASE TABLE'"
    " ORDER BY TABLE_NAME;"
)
tables = [row[0] for row in source_cursor.fetchall()]

# One row per foreign-key column. The referenced side is restricted to this
# schema as well: a cross-schema reference would otherwise be queued as a
# local table and SHOW CREATE TABLE on it would fail later.
source_cursor.execute("""
    SELECT TABLE_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
    FROM information_schema.KEY_COLUMN_USAGE
    WHERE TABLE_SCHEMA = DATABASE()
      AND REFERENCED_TABLE_SCHEMA = DATABASE()
      AND REFERENCED_TABLE_NAME IS NOT NULL
    ORDER BY TABLE_NAME, CONSTRAINT_NAME, ORDINAL_POSITION;
""")
foreign_keys = source_cursor.fetchall()

# table -> tables it references (each listed once; composite FKs yield one
# KEY_COLUMN_USAGE row per column, which would otherwise repeat the entry).
dependencies = defaultdict(list)
# table -> {fk column -> {'ref_table': ..., 'ref_column': ...}}.
# NOTE: each column of a composite FK is recorded (and later checked) on its own.
dependencies_columns = defaultdict(dict)
for table, column, ref_table, ref_column in foreign_keys:
    if ref_table not in dependencies[table]:
        dependencies[table].append(ref_table)
    dependencies_columns[table][column] = {
        'ref_table': ref_table,
        'ref_column': ref_column,
    }
# Tables in dependency order (referenced tables first), filled by visit(),
# plus the set of tables visit() has already handled.
sorted_tables = []
visited = set()


def visit(table):
    """Depth-first post-order walk that appends *table* to ``sorted_tables``.

    Every table listed in ``dependencies[table]`` is visited before *table*
    itself is appended, so a table a foreign key points at is emitted before
    the table that references it. The table is marked visited before its
    dependencies are walked, so FK cycles (including self-references)
    terminate; inside a cycle the order is simply whatever the walk hits first.
    """
    if table in visited:
        return
    visited.add(table)
    for referenced in dependencies.get(table, []):
        visit(referenced)
    sorted_tables.append(table)
def get_columns_info(table):
    """Return the source database's ``DESCRIBE`` rows for *table*.

    Each row is ``(Field, Type, Null, Key, Default, Extra)``; callers unpack
    all six. The table name is backtick-quoted so names that are reserved
    words (e.g. ``order``) or contain special characters still work.
    """
    source_cursor.execute(f"DESCRIBE `{table}`;")
    # NOTE(review): some mysql-connector / MySQL 8 combinations are reported to
    # return these text columns as bytes/bytearray; decode them so callers can
    # use str operations (.lower(), `in`) on them.
    return [
        tuple(
            value.decode("utf-8", "replace") if isinstance(value, (bytes, bytearray)) else value
            for value in row
        )
        for row in source_cursor.fetchall()
    ]
def get_idx_column_with_dependency(table, columns_info):
    """Map the positions of *table*'s foreign-key columns to their targets.

    Args:
        table: table name, used as the key into ``dependencies_columns``.
        columns_info: ``DESCRIBE`` rows for *table*; only each row's first
            field (the column name) is read.

    Returns:
        ``{index: {'column': name, 'ref': {'ref_table': ..., 'ref_column': ...}}}``
        for every column that appears in ``dependencies_columns[table]``,
        where *index* is the column's position in *columns_info*.
    """
    fk_columns = dependencies_columns[table]
    print(f'dependencies_columns {fk_columns}')
    return {
        position: {'column': name, 'ref': fk_columns[name]}
        for position, name in enumerate(info[0] for info in columns_info)
        if name in fk_columns
    }
# Base types (first word of DESCRIBE's ``Type`` value, lowercased) that get a
# zero / empty-string placeholder when a NOT NULL column has no value.
_NUMERIC_BASE_TYPES = frozenset({
    "tinyint", "smallint", "mediumint", "int", "integer", "bigint",
    "decimal", "numeric", "float", "double", "real",
})
_TEXT_BASE_TYPES = frozenset({
    "char", "varchar", "tinytext", "text", "mediumtext", "longtext",
})
# One quoted ENUM member in a type string; a literal quote inside a member is
# written as two single quotes.
_ENUM_MEMBER_RE = re.compile(r"'((?:[^']|'')*)'")


def _column_base_type(column_type):
    """Return the lowercase base type of a DESCRIBE ``Type`` string.

    ``"int(11) unsigned"`` -> ``"int"``, ``"enum('a','b')"`` -> ``"enum"``.
    Comparing base types avoids substring false positives such as ``point``
    matching ``int`` or an ENUM member list that happens to contain ``char``.
    """
    head = column_type.lower().split("(", 1)[0].split()
    return head[0] if head else ""


def _enum_members(column_type):
    """Return the members of an ``enum('a','b''c')`` type string, unescaped."""
    return [member.replace("''", "'") for member in _ENUM_MEMBER_RE.findall(column_type)]


def _is_generated_column(extra):
    """Return True for VIRTUAL/STORED generated columns.

    MySQL 8 also reports ``DEFAULT_GENERATED`` in ``Extra`` for ordinary columns
    whose default is an expression (e.g. ``DEFAULT CURRENT_TIMESTAMP``). Those
    hold real data and must be copied, so a bare ``"GENERATED" in extra`` test
    is not enough.
    """
    extra = (extra or "").upper()
    return "VIRTUAL GENERATED" in extra or "STORED GENERATED" in extra


def _not_null_placeholder(base_type):
    """Return the value used when a NOT NULL column of *base_type* has no value (None if none fits)."""
    if base_type in _NUMERIC_BASE_TYPES:
        return 0
    if base_type in _TEXT_BASE_TYPES:
        return ""
    if base_type in ("datetime", "timestamp"):
        return "1975-01-01 00:00:00"
    if base_type == "date":
        return "1970-01-01"
    if base_type == "time":
        return "00:00:00"
    return None


def _prepare_row(row, columns_info, generated_idx, default_values, enum_allowed_values):
    """Return *row* as a list of INSERT parameters, one per non-generated column."""
    prepared = []
    for i, value in enumerate(row):
        if i in generated_idx:
            continue  # the INSERT uses DEFAULT for generated columns
        column_name = columns_info[i][0]
        is_nullable = columns_info[i][2]
        members = enum_allowed_values.get(column_name)
        if members and value is not None and value not in members:
            # Value outside the ENUM's member list: NULL when allowed,
            # otherwise the column placeholder or the first member.
            if is_nullable == "YES":
                value = None
            else:
                value = default_values.get(column_name) or members[0]
        if value is None:
            value = default_values.get(column_name)
        prepared.append(value)
    return prepared


def insert_rows_to_target(rows, table):
    """Insert *rows* of *table*, as read from the source, into the target DB.

    Before each INSERT, rows referenced by the row's foreign keys are copied
    over first via ``insert_rows_ref_to_target_if_necessary``. Values are
    adjusted so the INSERT is accepted:

    - generated (VIRTUAL/STORED) columns are left to the server (``DEFAULT``);
    - ENUM values outside the member list become NULL, or, for NOT NULL
      columns, a type placeholder or the first member;
    - NULLs in NOT NULL columns get a type-appropriate placeholder.

    All rows are committed together at the end.

    Args:
        rows: rows from ``SELECT * FROM table`` on the source, i.e. values in
            ``DESCRIBE`` column order.
        table: table name (identical in source and target).
    """
    rows = list(rows)  # iterated twice below
    print('Inserting rows to target')
    columns_info = get_columns_info(table=table)
    idx_column_with_dependency = get_idx_column_with_dependency(table, columns_info)

    generated_idx = {
        i for i, col in enumerate(columns_info) if _is_generated_column(col[5])
    }
    default_values = {}
    enum_allowed_values = {}
    for column_name, column_type, is_nullable, _key, _default, _extra in columns_info:
        base_type = _column_base_type(column_type)
        if base_type == "enum":
            enum_allowed_values[column_name] = _enum_members(column_type)
        default_values[column_name] = (
            _not_null_placeholder(base_type) if is_nullable == "NO" else None
        )

    processed_rows = [
        _prepare_row(row, columns_info, generated_idx, default_values, enum_allowed_values)
        for row in rows
    ]
    if not processed_rows:
        return

    placeholders_array = []
    for i, column in enumerate(columns_info):
        if i in generated_idx:
            print(f'Skipping column {column[0]} as it is a generated column')
            placeholders_array.append('DEFAULT')
        else:
            placeholders_array.append("%s")
    placeholders = ", ".join(placeholders_array)
    columns = ", ".join(f'`{col[0]}`' for col in columns_info)
    insert_sql = f"INSERT INTO `{table}` ({columns}) VALUES ({placeholders})"
    print(f'Inserting data into {table}...')
    print(f'\nInsert SQL:\n {insert_sql}')
    print(f'\nRef Columns:\n len(idx_column_with_dependency) = {len(idx_column_with_dependency)}')
    for idx, column in idx_column_with_dependency.items():
        print(f'\t{idx} -> {column.get("column")}')
        print(f'\t{idx} -> {column["ref"]["ref_table"]}')
        print(f'\t{idx} -> {column["ref"]["ref_column"]}')

    for source_row, row in zip(rows, processed_rows):
        print(f'\n About to insert row: \t{row}')
        print("\n But let's check for reference first")
        # Pass the unmodified source row: its indices line up with columns_info,
        # whereas `row` has generated columns removed and would misalign the
        # FK column positions.
        insert_rows_ref_to_target_if_necessary(source_row, table)
        # Debug rendering only; the actual execute() below is parameterized.
        formatted_sql = insert_sql % tuple(
            f"'{value}'" if isinstance(value, str) else 'NULL' if value is None else value
            for value in row
        )
        print(f"Raw SQL: \n{formatted_sql}")
        target_cursor.execute(insert_sql, row)
    target_connection.commit()
def insert_rows_ref_to_target_if_necessary(row, table):
    """Make sure every row that *row*'s foreign keys point at exists in the target.

    For each FK column of *table* with a non-NULL value in *row*, the
    referenced row is looked up in the target. If it is missing, it is read
    from the source and copied with ``insert_rows_to_target``, which in turn
    resolves that row's own references.

    A row whose FK points at itself is skipped here, because the caller is
    about to insert it. Recursing would otherwise never terminate. Longer
    reference cycles between rows are not detected.

    Args:
        row: values of one *table* row in ``DESCRIBE`` column order (indices
            must line up with ``get_columns_info(table)``).
        table: name of the table *row* belongs to.

    Raises:
        RuntimeError: the referenced row is missing from the source too.
    """
    columns_info = get_columns_info(table=table)
    idx_column_with_dependency = get_idx_column_with_dependency(table, columns_info)
    for idx, column in idx_column_with_dependency.items():
        print(f'\t{idx} -> {column}')
        ref_table = column['ref']['ref_table']
        ref_column = column['ref']['ref_column']
        value = row[idx]
        if value is None:
            continue
        # Backtick-quote identifiers so reserved-word table/column names work.
        query = f"SELECT * FROM `{ref_table}` WHERE `{ref_column}` = %s LIMIT 1;"
        print(f'\n\tRef Check Query: {query}')
        target_cursor.execute(query, [value])
        if target_cursor.fetchall():
            print(f'\nREF already EXISTS {ref_table} {ref_column} = {value}')
            continue
        print(f'NEED TO INSERT {ref_table} {ref_column} = {value}')
        source_cursor.execute(query, [value])
        rows_ref_src = source_cursor.fetchall()
        if not rows_ref_src:
            raise RuntimeError(
                f'Referenced row {ref_table}.{ref_column} = {value!r} '
                f'(from {table}) not found in the source database'
            )
        if ref_table == table and tuple(rows_ref_src[0]) == tuple(row):
            # Self-reference: the caller inserts this very row next.
            continue
        insert_rows_to_target(
            rows=rows_ref_src,
            table=ref_table,
        )
# Number of sample rows copied from each source table (referenced rows are
# pulled in on top of these). Override with the SAMPLE_ROWS environment variable.
num_rows = int(os.getenv("SAMPLE_ROWS", "1"))

try:
    for table in tables:
        visit(table)

    for table in sorted_tables:
        print(f'Exporting table {table}')
        # Exact-name lookup: SHOW TABLES LIKE '<name>' treats '_' and '%' in
        # the name as wildcards and could match other tables.
        target_cursor.execute(
            "SELECT TABLE_NAME FROM information_schema.TABLES"
            " WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s;",
            (table,),
        )
        if not target_cursor.fetchall():
            source_cursor.execute(f"SHOW CREATE TABLE `{table}`;")
            create_table_sql = source_cursor.fetchone()[1]
            if isinstance(create_table_sql, (bytes, bytearray)):
                # NOTE(review): some connector versions return this column as bytes.
                create_table_sql = create_table_sql.decode("utf-8")
            # Zero-date defaults are presumably rejected by the target's SQL mode
            # (NO_ZERO_DATE / strict); swap them for expression defaults.
            create_table_sql = create_table_sql.replace("DEFAULT '0000-00-00'", "DEFAULT (CURRENT_DATE)")
            create_table_sql = create_table_sql.replace("DEFAULT '0000-00-00 00:00:00'", "DEFAULT (CURRENT_TIMESTAMP)")
            print(f'\nCreate Table SQL:\n {create_table_sql}')
            target_cursor.execute(create_table_sql)
        source_cursor.execute(f"SELECT * FROM `{table}` LIMIT {num_rows};")
        rows = source_cursor.fetchall()
        # NOTE(review): if the target table already holds these rows (e.g. a
        # re-run), the plain INSERT fails with a duplicate-key error.
        insert_rows_to_target(rows, table)
finally:
    source_cursor.close()
    source_connection.close()
    target_cursor.close()
    target_connection.close()
mysql-connector-python==8.0.33
python-dotenv==1.0.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment