Final Udacity Nanodegree Project: Predicting Sparkify Churn

Kate Asarar
Nov 19, 2020

Udacity offers a handful of very interesting projects and datasets for the final project of the Data Scientist Nanodegree. I chose a project built on a platform that was completely new to me, i.e. Apache Spark, so that I could become more familiar with it and get more experience handling big data. I learned a lot during the last couple of weeks and I'm happy to share my findings and my project with you in this post! :D

Problem statement: let's start with why!

Why do we care about churn, and what is it really? Churn is when a user chooses to cancel or downgrade a service. It can happen for many reasons: the user's financial situation, their needs for the service changing, etc.

The churn that companies care about most, however, is churn caused directly by their product. A user might cancel a subscription because they can't find what they want, because the GUI is badly designed and they keep having to visit the help page to understand what to do, or because they don't feel they are getting a good deal. It is in these cases that being able to predict churn can make the biggest difference for a company. By offering a wider selection, updating their platform or even sending out vouchers to soon-to-churn users, companies can keep their customers happy and retain their subscriptions for much longer.

Solution strategy and project description

In this project, a smaller 128 MB subset of the data was given to experiment on. The project goes through the stages of CRISP-DM to analyze, transform and model the data and to evaluate the resulting prediction models. The expected results are a data analysis as well as trained machine learning models that can classify a user as one who will churn or one who will stay.

Once all stages of cleaning, transforming, feature extraction, modeling and evaluation are ready, the scripts can be deployed to an AWS cluster to run on the full 12 GB dataset stored in an AWS S3 bucket using Spark.

Exploring the data

First, let's have a look at what kind of information is stored in the user log dataset:

data_original.printSchema()

root
|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

Let's check whether there is any missing (null) data so that it can be removed:

# Check number of null values in every column
Dict_Null = {column: data_original.filter(data_original[column].isNull()).count() for column in data_original.columns}
Dict_Null
{'artist': 58392,
'auth': 0,
'firstName': 8346,
'gender': 8346,
'itemInSession': 0,
'lastName': 8346,
'length': 58392,
'level': 0,
'location': 8346,
'method': 0,
'page': 0,
'registration': 8346,
'sessionId': 0,
'song': 58392,
'status': 0,
'ts': 0,
'userAgent': 8346,
'userId': 0}

There could, however, also be empty values that are not null, so let's check for those as well:

from pyspark.sql.functions import col

# Check number of empty strings in every column (isNull() does not catch these)
Dict_empty = {column: data_original.filter(col(column) == '').count() for column in data_original.columns}
Dict_empty
{'artist': 0,
'auth': 0,
'firstName': 0,
'gender': 0,
'itemInSession': 0,
'lastName': 0,
'length': 0,
'level': 0,
'location': 0,
'method': 0,
'page': 0,
'registration': 0,
'sessionId': 0,
'song': 0,
'status': 0,
'ts': 0,
'userAgent': 0,
'userId': 8346}
data_original.count()
286500

It looks like 8346 of the original 286500 rows (about 2.9% of the data) have an empty userId.
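The two checks differ because Spark's isNull() only matches real nulls, and in this log a missing userId is stored as an empty string. Below is a plain-Python sketch of the same two counts on three made-up rows. The rows and the helper count_missing are hypothetical and are not part of the project code.

```python
# Hypothetical event rows; the real data is a Spark DataFrame with 18 columns
rows = [
    {"userId": "30", "artist": "Martha Tilston", "page": "NextSong"},
    {"userId": "", "artist": None, "page": "Home"},
    {"userId": "9", "artist": None, "page": "Thumbs Up"},
]

def count_missing(rows, column):
    """Return (number of None values, number of empty strings) in one column."""
    nulls = sum(1 for r in rows if r[column] is None)
    empties = sum(1 for r in rows if r[column] == "")
    return nulls, empties

for column in ("userId", "artist", "page"):
    print(column, count_missing(rows, column))
# userId (0, 1)
# artist (2, 0)
# page (0, 0)
```

A null check alone would report userId as complete, so both counts are needed.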

# check pages visited where userId is missing
data_original.filter(col('userId') == '').select('page').distinct().show()
+-------------------+
| page|
+-------------------+
| Home|
| About|
|Submit Registration|
| Login|
| Register|
| Help|
| Error|
+-------------------+

Based on the pages visited where userIds are missing, we can conclude that the userId is only missing for users who are not logged in or registered yet.

I will not include all of the steps I took to clean and prepare the data here, but they can be found on my GitHub!
Before we move on, let's have a peek at the first row of data:

data_original.show(1)

+--------------+---------+---------+------+-------------+--------+---------+-----+---------------+------+--------+-------------+---------+---------+------+-------------+--------------------+------+
| artist| auth|firstName|gender|itemInSession|lastName| length|level| location|method| page| registration|sessionId| song|status| ts| userAgent|userId|
+--------------+---------+---------+------+-------------+--------+---------+-----+---------------+------+--------+-------------+---------+---------+------+-------------+--------------------+------+
|Martha Tilston|Logged In| Colin| M| 50| Freeman|277.89016| paid|Bakersfield, CA| PUT|NextSong|1538173362000| 29|Rockpools| 200|1538352117000|Mozilla/5.0 (Wind...| 30|
+--------------+---------+---------+------+-------------+--------+---------+-----+---------------+------+--------+-------------+---------+---------+------+-------------+--------------------+------+

Data Analysis

Let's plot and describe the data to see what we can use as features for the machine learning models. Some interesting features to consider are how many songs a user liked versus disliked during their subscription period, how many errors they ran into and how many times the artist was missing. The data is therefore cleaned and reduced to the relevant columns, as shown below.

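from pyspark.sql.functions import col, from_unixtime, udf
from pyspark.sql.types import IntegerType

# Flag churn events: a 'Cancellation Confirmation' or a 'Submit Downgrade' page visit
check_churn = udf(lambda page: int(page in ('Cancellation Confirmation', 'Submit Downgrade')), IntegerType())

# Drop rows without a userId, add the churn flag and turn the millisecond ts into a timestamp
data = data.filter(col('userId') != '') \
    .withColumn('churn', check_churn(col('page'))) \
    .withColumn('ts', col('ts') / 1000.0) \
    .withColumn('date', from_unixtime('ts', 'yyyy-MM-dd HH:mm:ss')) \
    .withColumn('date', col('date').cast('timestamp'))

# Keep only the columns used in the analysis
data = data.select('churn', 'date', 'status', 'userId', 'page', 'level', 'gender', 'sessionId', 'artist')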
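The two per-row transformations above can be reproduced in plain Python on the values from the first row shown earlier. This is a sketch, not the project code: the names CHURN_PAGES, churn_flag and to_datetime are made up for illustration. Note that Spark's from_unixtime formats in the session time zone, while this sketch uses UTC explicitly.

```python
from datetime import datetime, timezone

# Pages treated as churn, the same two values as in check_churn above
CHURN_PAGES = {"Cancellation Confirmation", "Submit Downgrade"}

def churn_flag(page):
    """Return 1 for a churn page, 0 otherwise (None also gives 0)."""
    return int(page in CHURN_PAGES)

def to_datetime(ts_ms):
    """Convert the millisecond 'ts' field to a timezone-aware UTC datetime."""
    return datetime.fromtimestamp(ts_ms / 1000.0, tz=timezone.utc)

# Values taken from the first row of the log
print(churn_flag("NextSong"))                   # 0
print(churn_flag("Cancellation Confirmation"))  # 1
print(to_datetime(1538352117000).isoformat())   # 2018-10-01T00:01:57+00:00
```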

Plotting the count of each distinct value for columns with fewer than 30 distinct values

As we can see in the plots, there are 4 times more paying users than users on the free option, and the gender distribution among users is relatively balanced. There are, however, a lot of missing artist values, which could make a useful feature, as could each user's page-visit counts.

Let's analyse the behaviour and experiences of the churned and stayed user groups. We will do this by looking at the differences in visited pages and in the number of missing artists, both as raw counts and relative to the duration of each user's subscription. Finally, we will look at how many sessions it takes a user to churn.

First, a separate DataFrame of userIds is created for each user type, so that we can count how many users belong to each category.

# Count the unique users in the dataset
all_userIds = data.select('userId').distinct()
print('Number of unique users in dataset: {}'.format(all_userIds.count()))

# A user counts as churned if any of their events has churn == 1
churned_userIds = data.filter('churn == 1').select('userId').distinct()
print('Number of users that churned: {}'.format(churned_userIds.count()))

# Every remaining user stayed
stayed_userIds = all_userIds.subtract(churned_userIds)
print('Number of users that stayed: {}'.format(stayed_userIds.count()))

Output:
Number of unique users in dataset: 225
Number of users that churned: 92
Number of users that stayed: 133
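The split above follows one rule: a user is churned if any of their events is flagged, and every other user stayed. Here is the same logic with Python sets on hypothetical (userId, churn) pairs. The event list is made up for illustration.

```python
# Hypothetical (userId, churn flag) pairs, one per logged event
events = [("10", 0), ("10", 0), ("11", 0), ("11", 1), ("12", 0), ("13", 1)]

all_user_ids = {user for user, _ in events}
# A user is churned as soon as one of their events carries churn == 1
churned_user_ids = {user for user, churn in events if churn == 1}
# Set difference plays the role of DataFrame.subtract()
stayed_user_ids = all_user_ids - churned_user_ids

print(len(all_user_ids), len(churned_user_ids), len(stayed_user_ids))  # 4 2 2
print(f"churn rate: {len(churned_user_ids) / len(all_user_ids):.1%}")  # churn rate: 50.0%
```

Applied to the real counts printed above, the same ratio gives 92 / 225 ≈ 40.9% churned users. The classes are therefore somewhat imbalanced, which is one reason to look at F1, recall and precision rather than accuracy alone.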

Let's start by looking at how many sessions a user has before deciding to leave:

Per quantile: 25% of the churned users churn after 5 sessions or fewer, 50% after 9 sessions or fewer, and 75% after 14 sessions or fewer.
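Quartiles like these can be computed by counting the distinct sessions of each churned user and taking quantiles of those counts. In Spark, DataFrame.approxQuantile does the quantile step. The plain-Python sketch below uses statistics.quantiles on made-up events instead. Two assumptions apply: "sessions" means distinct sessionIds, and the quantile definition may differ slightly from the one used in the project.

```python
import statistics
from collections import defaultdict

# Hypothetical (userId, sessionId) events of users who churned
churned_events = [
    ("a", 1), ("a", 2), ("a", 2), ("a", 3),
    ("b", 4),
    ("c", 5), ("c", 6), ("c", 7), ("c", 8), ("c", 9),
    ("d", 10), ("d", 11),
    ("e", 12), ("e", 13), ("e", 14), ("e", 15),
]

# Collect the distinct sessions of each churned user
sessions = defaultdict(set)
for user, session in churned_events:
    sessions[user].add(session)
sessions_per_user = sorted(len(s) for s in sessions.values())  # [1, 2, 3, 4, 5]

# 25%, 50% and 75% quantiles of the number of sessions before churning
q25, q50, q75 = statistics.quantiles(sessions_per_user, n=4, method="inclusive")
print(q25, q50, q75)  # 2.0 3.0 4.0
```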

As 75% of churned users leave after 14 sessions or fewer, the number of sessions a user has had so far can be a useful feature and is added to the set of features the machine learning models are trained on.

Moving on, let's have a look at the length of each session and the total subscription length in hours.

In these plots we can clearly see that the average session length is similar for the two user groups up until the 50-hour mark, after which stayed users keep having longer and longer sessions while churned users have shorter ones.
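Both quantities come from the first and last timestamps in a group. For session length the group is (userId, sessionId); for subscription length it is userId. In Spark this would be a groupBy with min and max aggregations. The plain-Python sketch below uses made-up events and assumes that subscription length means the time between a user's first and last logged event. The project may define it differently, for example from the registration timestamp.

```python
MS_PER_HOUR = 3_600_000

# Hypothetical (userId, sessionId, ts in milliseconds) events
events = [
    ("30", 1, 0), ("30", 1, 1_800_000), ("30", 1, 3_600_000),  # 1 h session
    ("30", 2, 90_000_000), ("30", 2, 97_200_000),              # 2 h session
    ("9", 7, 10_000_000), ("9", 7, 11_800_000),                # 0.5 h session
]

def span_hours(keyed_ts):
    """Map each key to (last ts - first ts) in hours."""
    bounds = {}
    for key, ts in keyed_ts:
        lo, hi = bounds.get(key, (ts, ts))
        bounds[key] = (min(lo, ts), max(hi, ts))
    return {key: (hi - lo) / MS_PER_HOUR for key, (lo, hi) in bounds.items()}

session_hours = span_hours(((user, session), ts) for user, session, ts in events)
# Assumed definition: subscription length = first to last event of the user
subscription_hours = span_hours((user, ts) for user, _, ts in events)

print(session_hours)       # {('30', 1): 1.0, ('30', 2): 2.0, ('9', 7): 0.5}
print(subscription_hours)  # {'30': 27.0, '9': 0.5}
```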

Lastly, let's plot the count of visited pages per group, normalized by the total number of users in the group, and the difference between the two groups:

The plot shows that churned users interact with more pages than staying users. This could indicate that users who churn actually start off liking the service and using it often, but do not find it useful in the long run. The relatively higher rate of error pages for churned users follows this trend of more page visits overall and does not indicate that churned users get a faulty version of the software or have a worse experience in general than staying users.
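The normalization behind this plot divides each group's page-visit counts by the number of users in that group, then subtracts one group from the other. Here is a plain-Python sketch on made-up events. The data and the helper pages_per_user are hypothetical.

```python
from collections import Counter

# Hypothetical (group, page) events and number of users per group
events = [
    ("churned", "NextSong"), ("churned", "NextSong"), ("churned", "Error"),
    ("stayed", "NextSong"), ("stayed", "Thumbs Up"), ("stayed", "Thumbs Up"),
    ("stayed", "NextSong"),
]
group_sizes = {"churned": 1, "stayed": 2}

def pages_per_user(group):
    """Page-visit counts of one group divided by the number of users in that group."""
    counts = Counter(page for g, page in events if g == group)
    return {page: n / group_sizes[group] for page, n in counts.items()}

churned = pages_per_user("churned")
stayed = pages_per_user("stayed")
# Positive: churned users visit the page more often per user than stayed users
difference = {page: churned.get(page, 0.0) - stayed.get(page, 0.0)
              for page in sorted(set(churned) | set(stayed))}
print(difference)  # {'Error': 1.0, 'NextSong': 1.0, 'Thumbs Up': -1.0}
```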

Stayed users still lead in adding songs to playlists and giving thumbs up, which are considered positive interactions with the product. Based on this analysis, I chose to use the page counts divided by the subscription duration as features for the machine learning models.

Feature extraction

The final feature set created for the machine learning models contains all the relevant information discussed in this post. In summary, these are the columns that were included:

['Error',
'Help',
'Home',
'Logout',
'NextSong',
'Roll Advert',
'Thumbs Down',
'Thumbs Up',
'Add to Playlist',
'duration in hours',
'churn',
'nr_sessions',
'missing_artist',
'activity',
'Error/duration',
'Help/duration',
'Home/duration',
'Logout/duration',
'NextSong/duration',
'Roll Advert/duration',
'Thumbs Down/duration',
'Thumbs Up/duration',
'Add to Playlist/duration',
'duration in hours/duration',
'churn/duration',
'nr_sessions/duration',
'missing_artist/duration',
'activity/duration']
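Each '/duration' column is the matching count divided by the user's subscription duration in hours. This makes users with short and long subscriptions comparable. Below is a plain-Python sketch of that step for one user. The helper per_duration and the zero-duration fallback are assumptions; in the project this was done on Spark DataFrames.

```python
def per_duration(features, duration_hours):
    """Divide each count feature by the subscription duration in hours.

    New keys follow the '<name>/duration' naming of the feature list above.
    A zero duration yields 0.0 instead of a division error (an assumption).
    """
    return {
        f"{name}/duration": (value / duration_hours if duration_hours > 0 else 0.0)
        for name, value in features.items()
    }

# Hypothetical counts for one user with a 24-hour subscription
user_counts = {"Thumbs Up": 12, "Error": 3, "nr_sessions": 6}
print(per_duration(user_counts, 24.0))
# {'Thumbs Up/duration': 0.5, 'Error/duration': 0.125, 'nr_sessions/duration': 0.25}
```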

Machine learning models and metrics

Two models are used in this work to predict churn: a logistic regression model and a linear support vector machine. Both were set up with PySpark's CrossValidator() to choose the best model after a grid search over hyperparameters, with the grid constructed using ParamGridBuilder().

Both models are evaluated using F1, recall and precision scores. The F1-score is the harmonic mean of recall and precision and is used to evaluate the models' overall performance. Recall measures the share of churning users that the model actually catches, so a low recall means churning users are missed; it is therefore arguably the most important metric for this project. Precision measures the share of users flagged as churning who really do churn. A higher precision means fewer staying users are falsely classified as churning, which cuts the losses from giving special offers to users who would have stayed anyway.
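Written out, with churn as the positive class (TP: churning users flagged as churning, FP: staying users flagged as churning, FN: churning users flagged as staying):

```latex
\text{Precision} = \frac{TP}{TP + FP}, \qquad
\text{Recall} = \frac{TP}{TP + FN}, \qquad
F_1 = 2\,\frac{\text{Precision}\cdot\text{Recall}}{\text{Precision} + \text{Recall}} = \frac{2\,TP}{2\,TP + FP + FN}
```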

Refinement

The models trained in this project needed some tuning of their hyperparameters. Since CrossValidator() and ParamGridBuilder() were used, however, this search was done automatically.

The logistic regression model was optimized using the following hyperparameters:

lr_paramGrid = ParamGridBuilder()\
.addGrid(lr_model.elasticNetParam, [0, 0.5, 1])\
.addGrid(lr_model.regParam, [0.1, 0.01])\
.addGrid(lr_model.maxIter, [50, 100, 200])\
.build()

where elasticNetParam sets the mix between an L1 and an L2 penalty in elastic net regularization (0 for pure L2, 1 for pure L1), and regParam is the regularization strength applied to the model weights to avoid overfitting. The parameter maxIter sets the maximum number of training iterations.
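For reference, the Spark ML documentation writes the elastic net penalty on the weights w as follows, with α = elasticNetParam and λ = regParam:

```latex
\alpha \left( \lambda \lVert \mathbf{w} \rVert_1 \right) + (1 - \alpha) \left( \frac{\lambda}{2} \lVert \mathbf{w} \rVert_2^2 \right),
\qquad \alpha \in [0, 1],\ \lambda \ge 0
```

So the grid above compares a pure L2 (ridge) penalty, an even mix and a pure L1 (lasso) penalty, each at two regularization strengths.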

The linear support vector machine was tuned over regParam, as explained above, and the maximum number of iterations:

svm_paramGrid = ParamGridBuilder()\
.addGrid(svm_model.maxIter, [50, 100])\
.addGrid(svm_model.regParam, [0.1, 1.0])\
.build()
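Each grid expands to every combination of its value lists, and CrossValidator trains one model per combination and fold. The plain-Python sketch below enumerates the same combinations with itertools.product to show the size of the search. The helper expand is hypothetical. The fold count of 3 is an assumption taken from CrossValidator's default numFolds, since the value used in the project is not stated.

```python
from itertools import product

# Same value lists as the two ParamGridBuilder grids above
lr_grid = {
    "elasticNetParam": [0, 0.5, 1],
    "regParam": [0.1, 0.01],
    "maxIter": [50, 100, 200],
}
svm_grid = {"maxIter": [50, 100], "regParam": [0.1, 1.0]}

def expand(grid):
    """List every parameter combination (the same set ParamGridBuilder.build() yields)."""
    names = list(grid)
    return [dict(zip(names, values)) for values in product(*grid.values())]

lr_maps, svm_maps = expand(lr_grid), expand(svm_grid)
print(len(lr_maps), len(svm_maps))  # 18 4

num_folds = 3  # assumption: CrossValidator's default numFolds
print("LR fits:", len(lr_maps) * num_folds, "SVM fits:", len(svm_maps) * num_folds)  # LR fits: 54 SVM fits: 12
```

On top of these fits, CrossValidator refits the best parameter combination once on the full training data. This is why adding values to the grid, or raising maxIter, quickly becomes slow.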

In the first draft of this project, the missing-artist count was not used as a feature; that version reached an F1-score of 0.84, a recall of 0.8076 and a precision of 0.875. The best performance after adding this feature is presented in the Results section.

Model evaluation

The models were evaluated using the three metrics discussed above (F1, recall and precision). These were calculated by first transforming the model predictions into a pandas DataFrame and then using the built-in functions of the scikit-learn library. Furthermore, the confusion matrix of each model was plotted to visualize the final prediction results. For this, the following two helper functions were used.

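import seaborn as sns
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score


def print_metrics(predictions):
    """Print and return F1, recall and precision for a pandas DataFrame of predictions.

    Expects 'label' and 'prediction' columns, e.g. Spark predictions converted with toPandas().
    """
    y_test = predictions.label
    y_pred = predictions.prediction
    f1 = f1_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred)
    model_metrics = {'f1': f1,
                     'recall': recall,
                     'precision': precision}
    print('F1-Score: {}, Recall: {}, Precision: {}'.format(f1, recall, precision))

    return model_metrics


def plot_confusion_matrix(predictions):
    """Draw the confusion matrix (rows: true label, columns: predicted label) as a heatmap."""
    y_test = predictions.label
    y_pred = predictions.prediction
    cm = confusion_matrix(y_test, y_pred)
    sns.heatmap(cm, annot=True, cmap='Blues')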

Please refer to the Results section for the discussion of the evaluation.

Results

The logistic regression model performed best, achieving an F1-score of 0.894, a recall of 0.85 and a precision of 0.94, with the following confusion matrix:

Confusion Matrix Logistic Regression

As can be seen in the logistic regression's confusion matrix, only one of the 28 staying users was falsely labeled as churning, and only 3 of the 20 churning users were labeled as staying.

The support vector machine was less successful in predicting churn among users, reaching an F1-Score of 0.6, a Recall score of 0.45 and a high Precision of 0.9 with a confusion matrix as shown in the plot below.

Confusion Matrix SVM

As can be seen in the support vector machine's confusion matrix, only one of the 28 staying users was falsely labeled, while the majority of the churning users, 11 out of 20, were falsely labeled as staying. I take this as an indication that this type of model is a poor fit for this problem. One measure that could improve its performance is increasing the maximum number of training iterations. Training was, however, too slow and no significant improvement was expected from this, so it was not pursued.
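As a sanity check, the reported scores can be recomputed from the confusion-matrix counts described above: 20 churning and 28 staying users in the test set, assuming both models were evaluated on the same split. The helper scores_from_counts is a small illustration, not part of the project code.

```python
def scores_from_counts(tp, fp, fn):
    """Precision, recall and F1 with churn as the positive class."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Logistic regression: 3 of 20 churners missed, 1 of 28 stayers flagged
lr = scores_from_counts(tp=17, fp=1, fn=3)
# SVM: 11 of 20 churners missed, 1 of 28 stayers flagged
svm = scores_from_counts(tp=9, fp=1, fn=11)

print("LR  precision={:.3f} recall={:.3f} f1={:.3f}".format(*lr))   # LR  precision=0.944 recall=0.850 f1=0.895
print("SVM precision={:.3f} recall={:.3f} f1={:.3f}".format(*svm))  # SVM precision=0.900 recall=0.450 f1=0.600
```

The logistic regression F1 is 34/38 ≈ 0.8947, which the text reports truncated as 0.894. The SVM scores match exactly.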

The difference in performance between the two models is not entirely unexpected: logistic regression is well suited to binary classification problems with a small number of features and a clearly defined dependent variable. Support vector machines, however, are also a powerful tool for a wide range of classification and regression problems, and comparing the two shows that logistic regression is clearly the better fit here.

Conclusion

This project aims at understanding and predicting churn in a small user dataset using Apache Spark's analytics engine and its powerful libraries.

The project goes through the technical steps of the CRISP-DM process, starting with data understanding, followed by feature extraction, modeling and finally evaluation.

Using the methods presented in this Jupyter notebook and the best-performing model, the logistic regression classifier, we are able to correctly identify 85% of the unhappy users in the test set and offer them discounts or rewards to keep their subscriptions. Furthermore, with the high precision of about 94% achieved by the model, we can save money by not giving offers to customers who are already happy and staying.

Future work can include adding a location feature to include the influence of internet connection quality on churn as well as trying out other machine learning models such as Random Forest classifiers.

Here is the link to the GitHub repo for this project. Please have a look and leave a comment with some feedback; I would love to hear what you think :D

https://github.com/kate-asarar/FinalProject_spark
