Using partial
October 01, 2023
Some libraries in Python provide methods or functions that let the user pass a function object as an argument. One limitation I kept running into was that there was no way to pass extra arguments along with that user-defined function object. This was a problem until I discovered the role of `partial`.
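To make the pattern concrete, here is a minimal sketch of the kind of API I mean; `run_callback` and `process` are hypothetical stand-ins, not from any of the libraries used below:

from typing import Callable

def run_callback(func: Callable[[], None]) -> None:
    # A library-style API: it only knows how to call func with no arguments.
    func()

def process(bucket: str, prefix: str) -> None:
    print(f"Processing {bucket}/{prefix}")

# Passing process directly fails: run_callback(process) raises a TypeError,
# because process needs two arguments that run_callback cannot supply.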
The problem
Two use cases, same problem:
from functools import partial
from airflow import DAG
from airflow.models import Variable
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from google.cloud.sql.connector import Connector
import logging
import sqlalchemy
import os


def process_bi_tables():
    prefix = "parent/path/folder"
    bucket = Variable.get("BUCKET")
    region = Variable.get("REGION")
    project_id = Variable.get("PROJECT_ID")
    postg_connex_string = f"{project_id}:{region}:{Variable.get('POSTGRE_SERVER_NAME')}"

    # Inner function with no parameters: everything it needs is hardcoded
    # or captured from the enclosing scope.
    def postgre_pool():
        connector = Connector()
        return connector.connect(
            instance_connection_string=postg_connex_string,
            driver="pg8000",
            user=os.environ.get("DB_USER"),
            password=os.environ.get("DB_PASS"),
            db=os.environ.get("DB_NAME"),
        )

    postgre_pool = sqlalchemy.create_engine(
        "postgresql+pg8000://",
        creator=postgre_pool,
    )

    tables = ["patients_doctors", "nights", "users", "doctors", "sensors"]
    views = ["nights_view"]
    with postgre_pool.connect() as postgre_conn:
        for name in tables + views:
            logging.info(f"Processing table : {name}")
            # read_query is not defined here to be brief
            query = read_query(bucket=bucket, prefix=prefix, filename=name)
            postgre_conn.execute(query)
    logging.info("BI databases updated")


with DAG(
    dag_id="pipeline_bi",
    default_args={"owner": "airflow"},
    schedule_interval="0 */3 * * *",
    start_date=days_ago(1),
    catchup=False,
) as dag:
    update_bi_tables = PythonOperator(
        task_id="update_bi_tables",
        python_callable=partial(
            process_bi_tables,
        ),
    )
Let's have a look at the `process_bi_tables` function:

- The purpose of this function is to perform some data processing using SQL scripts. We iterate through table names and load SQL scripts named the same way. The BI pipeline and Composer instances are deployed in different environments in the US and Europe. By default, Cloud Composer sets up a bucket where you can find the `airflow.cfg` file and upload your DAGs and Python modules. You can add there any folders and files you need to use; in our case, we need to load these SQL scripts.
- First, you can see in the definition of the `process_bi_tables` function that we defined the variables `bucket` and `prefix`. They are defined inside the function, which is not convenient if I want the code to be modular for each environment and to avoid defining some Airflow variables each time. The `process_bi_tables` function is passed to the `PythonOperator` as an argument.
- Second, notice the way I created the `postgre_pool` function. It is an inner function that uses a `Connector` object to set up the connection with my database. This function is passed to `sqlalchemy.create_engine` as its `creator` argument; a stripped-down sketch of the constraint this imposes follows the list.
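To see why the inner function had to take no parameters: `create_engine` invokes `creator` with zero arguments to obtain a DBAPI connection, so every connection detail has to be baked in. A stripped-down sketch of that constraint, using the standard-library `sqlite3` driver just to keep it self-contained:

import sqlite3
import sqlalchemy

def make_connection():
    # create_engine calls creator with no arguments, so the target database
    # has to be hardcoded or captured from an enclosing scope.
    return sqlite3.connect(":memory:")

engine = sqlalchemy.create_engine("sqlite://", creator=make_connection)
with engine.connect() as conn:
    print(conn.execute(sqlalchemy.text("SELECT 1")).scalar())  # 1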
I forced myself to define these two functions without any parameters because, at first, I had no idea how to pass the function object as an argument together with its args. To circumvent this issue I had to define some variables inside the function body, which makes the code poorly maintainable and not that beautiful to read. This was until I found `partial`.
`partial` in action
In the documentation you read:

Return a new `partial` object which when called will behave like `func` called with the positional arguments `args` and keyword arguments `keywords`. If more arguments are supplied to the call, they are appended to `args`. If additional keyword arguments are supplied, they extend and override `keywords`.
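The appending and overriding behaviour is easy to verify in a REPL; `power` here is just a toy function for illustration:

from functools import partial

def power(base, exponent):
    return base ** exponent

square = partial(power, exponent=2)
print(square(3))              # 9: the call's positional arg is appended to args
print(square(3, exponent=5))  # 243: the call's keyword overrides the stored one
print(partial(power, 2)(5))   # 32: stored positional args come first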
This looks like exactly what I need. Let's refactor the code using `partial`, first with `postgre_pool`:
# Factored out into a shared module, e.g. utils.py
def connect(connstring: str, user: str, password: str, db: str, driver: str):
    connector = Connector()
    return connector.connect(
        instance_connection_string=connstring,
        driver=driver,
        user=user,
        password=password,
        db=db,
    )


def process_bi_tables():
    prefix = "parent/path/folder"
    bucket = Variable.get("BUCKET")
    region = Variable.get("REGION")
    project_id = Variable.get("PROJECT_ID")
    postg_connex_string = f"{project_id}:{region}:{Variable.get('POSTGRE_SERVER_NAME')}"

    postgre_pool = sqlalchemy.create_engine(
        "postgresql+pg8000://",
        # partial binds every parameter of connect, producing the
        # zero-argument callable that creator expects.
        creator=partial(
            utils.connect,
            user=Variable.get("POSTGRE_DB_USER"),
            password=Variable.get("POSTGRE_DB_PASSWORD"),
            db=Variable.get("POSTGRE_DB_NAME"),
            driver="pg8000",
            connstring=postg_connex_string,
        ),
    )
We factored out the `postgre_pool` function into a standalone function called `connect`. To use it in our `create_engine` call we used the following construct:

partial(
    my_function,
    **kwargs
)

The beauty of `partial` is that it bundles the function object together with its arguments into a single callable, which is pretty convenient. Moreover, we can now use `connect` for any database, and we no longer need inner functions with hardcoded arguments.
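To check the construct behaves as advertised, here is the same shape with a stand-in for `utils.connect` and placeholder values:

from functools import partial

# A stand-in with the same signature as utils.connect:
def connect(connstring, user, password, db, driver):
    return f"{driver} connection to {db} as {user}"

creator = partial(
    connect,
    connstring="project:region:instance",
    user="bi_user",
    password="***",
    db="bi",
    driver="pg8000",
)

print(creator())  # pg8000 connection to bi as bi_user -- no arguments needed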
We can do a similar move for the `process_bi_tables` function:
def process_bi_tables(bucket: str, prefix: str):
    region = Variable.get("REGION")
    project_id = Variable.get("PROJECT_ID")
    postg_connex_string = f"{project_id}:{region}:{Variable.get('POSTGRE_SERVER_NAME')}"

    postgre_pool = sqlalchemy.create_engine(
        "postgresql+pg8000://",
        creator=partial(
            utils.connect,
            user=Variable.get("POSTGRE_DB_USER"),
            password=Variable.get("POSTGRE_DB_PASSWORD"),
            db=Variable.get("POSTGRE_DB_NAME"),
            driver="pg8000",
            connstring=postg_connex_string,
        ),
    )

    tables = ["patients_doctors", "nights", "users", "doctors", "sensors"]
    views = ["nights_view"]
    with postgre_pool.connect() as postgre_conn:
        for name in tables + views:
            logging.info(f"Processing table : {name}")
            query = utils.read_query(bucket=bucket, prefix=prefix, filename=name)
            postgre_conn.execute(query)
    logging.info("BI databases updated")
Now both `bucket` and `prefix` are function parameters. With `partial` we can set them in our `PythonOperator`:
with DAG(
    dag_id="pipeline_bi",
    default_args={"owner": "airflow"},
    schedule_interval="0 */3 * * *",
    start_date=days_ago(1),
    catchup=False,
) as dag:
    update_bi_tables = PythonOperator(
        task_id="update_bi_tables",
        python_callable=partial(
            process_bi_tables,
            bucket=Variable.get("BUCKET-AIRFLOW"),
            prefix="query/task_bi_processing",
        ),
    )
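A nice side effect is that the object handed to `python_callable` keeps its pieces inspectable through the `func`, `args` and `keywords` attributes of the `partial`, which helps when debugging a DAG; a small sketch with placeholder values (a full local run would still need the Airflow Variables used inside the function to resolve):

from functools import partial

task_fn = partial(
    process_bi_tables,
    bucket="my-bucket",                 # placeholder instead of Variable.get(...)
    prefix="query/task_bi_processing",
)
# Inspect what the operator will call, without executing anything:
print(task_fn.func.__name__, task_fn.keywords)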
Conclusion
I am glad I found my way to `partial`. It was specifically designed for this use case, and it is very elegant to use in your code.