This article describes the journey of implementing the isolation forest, an unsupervised learning algorithm for anomaly / outlier detection, from scratch.
- I could not find a Scala implementation of the algorithm, and implementing the complete algorithm from scratch gave me a good understanding of it, including its good and bad parts …
- I wrapped the implementation in a small project that can read actual data from a Kafka topic …
- Everybody needs a hobby …
What’s the goal?
The isolation forest is an unsupervised learning algorithm that inspects numerical, multidimensional data and identifies outliers / point anomalies.
Let’s split up the above statement and explain it in detail.
- The algorithm works on numerical multidimensional data: this basically means that you have some entities with numerical properties. An example would be a customer order with the following numerical properties: a price, a margin, the creation throughput time in minutes, etc.
- Outlier / point anomaly detection: the algorithm identifies rare, abnormal items, events or observations that raise suspicion by differing significantly from the majority of the data.
That sounds simple?
Identifying a point anomaly is an easy task in one dimension. If I gave you the following order margins, it would be easy to identify the abnormal margin in the list below:
[30%, 29%, 31%, 29%, 27%, 32%, 33%, 28%, 4%, 29%]
The 4% margin stands …
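For a single dimension, a naive automated check is a z-score test: flag the values that lie far from the mean, measured in standard deviations. The minimal Scala sketch below is an illustration, not part of the project's code; the object name and the threshold `k = 2.5` are arbitrary assumptions, and the check only makes sense for roughly normally distributed data.

```scala
object ZScoreCheck {
  // Order margins from the list above, in percent.
  val margins: Seq[Double] = Seq(30.0, 29.0, 31.0, 29.0, 27.0, 32.0, 33.0, 28.0, 4.0, 29.0)

  // Flags values that lie more than k (population) standard deviations from the mean.
  // This silently assumes a roughly normal distribution.
  def outliers(xs: Seq[Double], k: Double): Seq[Double] = {
    val mean = xs.sum / xs.size
    val std  = math.sqrt(xs.map(x => math.pow(x - mean, 2)).sum / xs.size)
    xs.filter(x => math.abs(x - mean) / std > k)
  }

  def main(args: Array[String]): Unit =
    println(outliers(margins, k = 2.5))
}
```

With `k = 2.5`, only the 4% margin is flagged (its z-score is about 2.9).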
But now let’s add the following complexity:
- From an automation perspective, you don’t know the distribution of the margin. Maybe it’s a nice normal distribution, and then you can check the deviation from the average to decide whether a point is an outlier, but that assumption is not always true. The distribution below has anomalies that are close to the average (the middle red cross).
- An order has a lot of properties, and when we combine multiple properties, the problem becomes much harder (since each property again has an unknown distribution). The section below explains this in detail.
We will describe the problem in 2 dimensions because it’s easy to visualize, but in real cases the number of dimensions can be much larger.
In the below picture we look at the margin and the total price of a collection of orders (each cross is an order). The red crosses are clearly anomalies and the black crosses are clearly normal cases.
- If you only look at the price, the red crosses appear normal.
- If you only look at the margin, it’s the same story.
With thousands of orders and more than 2 dimensions, manually detecting the anomalous orders would be a hard job.
The goal is to detect the anomalous entities in real time. We don’t want to collect the data in a CSV file, a database or some other medium and run the anomaly detection the day after the data collection, as is common in a BI report context.
I want to read a live data stream and detect the anomalous entities in it. I have used a Kafka topic as an example, but it can be any streaming data medium.
I have chosen a tumbling window strategy based on a configurable record count. The picture below shows the flow that we want to implement in a continuous loop:
- Take the next 500 records in the stream of data
- Perform anomaly detection on those 500 records and if there is an anomaly, perform some custom logic.
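The count-based tumbling window can be sketched with plain Scala iterators, independent of Kafka. This is a hypothetical illustration: `TumblingWindow`, `windows` and `run` are names made up here, and `detect` stands in for the anomaly detection step.

```scala
object TumblingWindow {
  // Count-based tumbling windows: consecutive, non-overlapping chunks of `size` records.
  // If the input ends, the last chunk may be smaller.
  def windows[A](records: Iterator[A], size: Int): Iterator[Seq[A]] =
    records.grouped(size)

  // Runs the detection step once per window, in order.
  def run[A](records: Iterator[A], size: Int)(detect: Seq[A] => Unit): Unit =
    windows(records, size).foreach(detect)

  def main(args: Array[String]): Unit =
    run((1 to 1200).iterator, 500)(w => println(s"window of ${w.size} records"))
}
```

With 1,200 records and a window size of 500, `main` handles windows of 500, 500 and 200 records.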
Isolation forest algorithm
This section focuses on the anomaly detection part of the solution. To understand our solution we first need to talk about the different types of anomaly detection. I already stated that the algorithm will identify point anomalies or outliers.
Point anomaly detection?
Point anomaly detection is a specific type of anomaly detection. As the term says, a point anomaly is a single instance that can be considered anomalous compared to the other instances in the dataset. Point anomalies often represent an extreme value, irregularity or deviation that happens randomly and has no particular meaning. They are also called outliers. This type of anomaly detection is possible in one or multiple dimensions.
When you add time or another contextual feature to the mix, we talk about contextual anomalies; finding those is not the goal of this article.
The isolation forest algorithm
Anomaly detection usually happens in scenarios where you don’t have a balanced dataset (a balanced dataset has a comparable number of examples for each case of interest). Let’s consider insurance fraud. Suppose we have a dataset of transactions and we know that it contains a lot of fraud.
Fraud investigations are expensive, and you’d better have a good argument before investigating a customer for fraud. In a fraud detection dataset, we may have 1% labeled fraud cases while there are probably around 10% actual fraud cases.
This puts us in a tough spot. We don’t know the remaining 9% of fraud cases: they were never discovered. That makes it hard to build a dataset of normal cases. Many anomaly detection approaches (for example autoencoders) require a dataset of normal cases.
The isolation forest algorithm requires no labeling of normal cases. It doesn’t model the normal cases first in order to measure the deviation from them when looking for anomalies.
Isolation forests remove the labeling work and take a simple but effective approach with other advantages:
- It has a low linear time complexity and a small memory requirement
- It is able to deal with high dimensional data with irrelevant attributes
- It can be trained with or without anomalies in the training set
Each point in the training set can be scored with the result of the training, and you can also score new points with the generated isolation forest model. In our scenario, we want to train the algorithm every time a new chunk of data is evaluated. This only works if training is fast enough to keep up with the stream in real time.
If your training process is not fast enough, then it’s wise to train the model once every X stages and use that model for inference in the next stages until a new training is scheduled.
In this article, we focus on the strategy shown in the picture above: every time we take a chunk of data, we train a model on that data and use that same model on the training data to get a result for each entity in the chunk.
Intuition behind isolation forests?
Instead of trying to build a model of normal instances, it explicitly isolates anomalous points in the dataset. It does it via a very simple principle. Anomalous data points in a sample have the following properties:
- Few — they are the minority consisting of fewer instances and
- Different — they have attribute-values that are very different from those of normal instances
So how does it exploit this? Let’s look at the following simple one-dimensional problem. We have some blue points that are spread over one dimension.
We choose a random split (black line below) somewhere between the first point and the last point. This gives us 2 parts for the one dimensional dataset.
For each part we do the same thing again. We choose a random split somewhere between the first point and the last point (red line below).
This continues until we have one point in each part. We can represent this splitting mechanism as a binary tree structure.
In that tree, outliers are likely to end up higher: the number of nodes you pass from the root to the node of a point tends to be lower for outlier points.
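A minimal sketch of the one-dimensional idea, assuming the points are plain `Double`s and that `x` is one of them. Instead of building the whole tree, it only follows the part that contains `x` and counts the splits needed to isolate it; the object and method names are illustrative, not taken from the project.

```scala
import scala.util.Random

object OneDimIsolation {
  // Number of random splits needed before x is alone in its part (or cannot be split further).
  // Each split is a uniform random value between the current min and max; only the part
  // that contains x is followed, because the other part does not affect x's depth.
  @annotation.tailrec
  def depth(points: Seq[Double], x: Double, rng: Random, acc: Int = 0): Int =
    if (points.size <= 1 || points.min == points.max) acc
    else {
      val lo    = points.min
      val split = lo + rng.nextDouble() * (points.max - lo)
      val part  = if (x < split) points.filter(_ < split) else points.filter(_ >= split)
      depth(part, x, rng, acc + 1)
    }

  def main(args: Array[String]): Unit = {
    val points = Seq(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 10.0)
    val rng    = new Random(42)
    println("depth of 10.0 in 5 random trees: " + Seq.fill(5)(depth(points, 10.0, rng)))
    println("depth of 1.4 in 5 random trees:  " + Seq.fill(5)(depth(points, 1.4, rng)))
  }
}
```

Over many runs, the isolated point 10.0 typically needs about one split, while a point inside the dense cluster needs several.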
This was a very simple example in one dimension. Let’s look at 2 dimensions and remember the explanation regarding multidimensional points in the beginning of the article…
The red point is clearly an anomaly …
But we don’t have the luxury of looking at a single dimension to search for anomalies. If I only look along the orange axis, the red point is not an anomaly. If I only look along the green axis, the red point is not an anomaly either.
Let’s see how the splitting strategy we discussed in one dimension might help. We choose one axis and split the data in 2 pieces (with a randomly chosen line in black).
For each piece we do the same action. We split each area with a line but this time we choose the other axis for our random line (green lines).
We see that our red point is isolated more quickly than our blue points. This was just a toy example, but the isolation forest algorithm exploits this idea.
Great … So we have this splitting strategy where we actually build a tree structure, and our anomalies have a higher chance of ending up high in the tree.
Since this splitting strategy is random and non-deterministic, a single tree is not really reliable. The isolation forest algorithm solves this by building a lot of trees (hence the name :-))
In the picture above, you can see that we have built 4 trees in which our data point is located at different depths:
- tree left top : 4
- tree right top : 3
- tree left bottom : 3
- tree right bottom : 2
We simply take the average depth over all the trees in our forest; this average is our indication of whether the data point is an anomaly (the lower, the more anomalous). In our case, avg = (4 + 3 + 3 + 2) / 4 = 3.
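The same idea in several dimensions, together with the averaging over trees, can be sketched as follows. It is a simplified illustration under stated assumptions: points are `Vector[Double]`, each split picks a random dimension among those that still contain different values, and trees are not materialized (only the path of `x` is followed). The data in `main` is made up.

```scala
import scala.util.Random

object RandomIsolationAverage {
  type Point = Vector[Double]

  // Depth of x in one randomly grown tree: each split picks a random dimension (among those
  // that still have different values) and a random split value between its min and max.
  @annotation.tailrec
  def depth(points: Seq[Point], x: Point, rng: Random, acc: Int = 0): Int = {
    val dims = x.indices.filter(i => points.map(_(i)).distinct.size > 1)
    if (points.size <= 1 || dims.isEmpty) acc
    else {
      val i     = dims(rng.nextInt(dims.size))
      val col   = points.map(_(i))
      val split = col.min + rng.nextDouble() * (col.max - col.min)
      // keep only the points on the same side of the split as x
      depth(points.filter(p => (p(i) < split) == (x(i) < split)), x, rng, acc + 1)
    }
  }

  // The "forest" part: average the depth of x over many independent random trees.
  def averageDepth(points: Seq[Point], x: Point, numTrees: Int, rng: Random): Double =
    (1 to numTrees).map(_ => depth(points, x, rng)).sum.toDouble / numTrees

  def main(args: Array[String]): Unit = {
    val cluster = for (a <- 0 until 5; b <- 0 until 5) yield Vector(a.toDouble, b.toDouble)
    val outlier = Vector(20.0, 20.0)
    val data    = cluster :+ outlier
    val rng     = new Random(7)
    println(f"outlier: ${averageDepth(data, outlier, 100, rng)}%.2f")
    println(f"inlier:  ${averageDepth(data, Vector(2.0, 2.0), 100, rng)}%.2f")
  }
}
```

In `main`, the point (20, 20), far away from the 5 × 5 grid, should get a clearly lower average depth than the grid point (2, 2).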
Some extra important things to note:
- Each tree is built with a relatively small subset of all the data (subsampling).
- Every split is done on a random dimension of our data.
- Each tree has a max depth. This means that you only split up to that depth and then simply stop. This gives an “incomplete” tree, but that’s fine: it keeps the memory footprint very low.
The subsampling strategy helps for the following reasons:
- Swamping: when normal instances are too close to anomalies, the number of partitions required to separate anomalies increases, a phenomenon known as swamping, which makes it more difficult for an algorithm to discriminate between anomalies and normal points. One of the main reasons for swamping is the presence of too much data for the purpose of anomaly detection, which suggests sub-sampling as one possible solution. Since isolation forest responds very well to sub-sampling in terms of performance, reducing the number of points in the sample is also a good way to reduce the effect of swamping.
- Masking: when the number of anomalies is high it is possible that some of those aggregate in a dense and large cluster, making it more difficult to separate the single anomalies and, in turn, to detect such points as anomalous. Similarly to swamping, this phenomenon (known as “masking”) is also more likely when the number of points in the sample is big, and can be alleviated through sub-sampling.
Our code provides a Kafka consumer implementation for the data ingestion part. The consumer performs anomaly detection on the JSON data of one topic.
A Kafka topic can have multiple partitions, and each partition can be assigned to a specific consumer in a consumer group. Our anomaly detection algorithm should inspect the complete data of the topic (coming from multiple partitions). This means that our mechanism should be efficient (non-blocking, parallel processing where possible), and it can’t be scaled out over multiple consumers in a consumer group: each consumer would only see the partitions assigned to it, and we don’t want to store our data in some medium (a queue or database or … ) and perform the anomaly detection on that persistent medium.
Let’s go through the code step by step. Each section covers a particular portion of the code, and all source code can be found on GitHub here.
Our code should inspect Kafka topics for anomalies. We provide an easy configuration file where you can specify the following:
- The Kafka topic that we want to inspect.
- The Kafka consumer / group id settings
- The bootstrap servers for your connection with Kafka
- The JSON label that contains the unique ID of our record (we don’t assume that there is a Kafka key in the message)
- The JSON label anomaly fields that contain our numeric data that we want to inspect
Our sample configuration file looks like this:
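As a hypothetical sketch (not the project's actual file), the settings listed above could be modeled in Scala as a case class. The class name, the field names and the default batch size of 500 (taken from the tumbling window description) are assumptions, not the project's configuration keys.

```scala
// Hypothetical settings model; field names are illustrative, not the project's actual keys.
final case class AnomalyDetectionConfig(
  topic: String,                  // Kafka topic to inspect
  groupId: String,                // consumer group id
  clientId: String,               // consumer client id
  bootstrapServers: List[String], // Kafka brokers to connect to
  idField: String,                // JSON label holding the record's unique id
  anomalyFields: List[String],    // JSON labels of the numeric fields to inspect
  batchSize: Int = 500            // records per tumbling window
) {
  require(anomalyFields.nonEmpty, "at least one numeric field is needed")
  require(batchSize > 0, "batch size must be positive")
}
```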
Building a tree
Our tree will be a binary tree structure where you have internal nodes and external nodes.
- An external node is a leaf in the tree.
- An internal node will be split into 2 nodes that can be either an internal node or an external node.
Remember that each split is done by choosing a random variable/dimension in the data. We can represent our tree as a recursive algebraic data type (ADT). Our tree consists of ITree nodes, which are either internal or external nodes; depending on the type, a node carries some extra data members.
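/** ADT for the Isolation tree */
sealed trait ITree[+A]

// External node (leaf): value holds the data that ended up in this leaf
case class ExternalNode[A](value: A, featureName: String, featureIndex: Int) extends ITree[A]

// Internal node: split on the feature at featureIndex at splitValue, with two subtrees
case class InternalNode[A](splitValue: Double, featureName: String, featureIndex: Int,
                           left: ITree[A], right: ITree[A]) extends ITree[A]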
In the section below, you can find the pseudocode from the paper for our recursive tree-building function. We need a function that receives some input data and a height limit for the output tree.
Since it’s a recursive function, it has three parameters:
- The input (which is a subset of all the data)
- The current tree height
- The maximum depth of the tree
First we need some helper functions:
- A function that generates a random double between the min and max of a dimension in our data (generate random split)
- A function that generates a random integer for selecting a feature in our data
- A function that splits our data (a list) into 2 lists based on the index of the feature and the split value.
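A minimal sketch of the three helpers, assuming each record is a `Vector[Double]` of feature values and that `scala.util.Random` provides the randomness. The names `randomSplit`, `randomFeature` and `splitData` are illustrative, not the project's.

```scala
import scala.util.Random

object ITreeHelpers {
  type Point = Vector[Double]

  // Random split value between the min and max of the chosen feature (uniform in [min, max)).
  def randomSplit(data: Seq[Point], featureIndex: Int, rng: Random): Double = {
    val column = data.map(_(featureIndex))
    column.min + rng.nextDouble() * (column.max - column.min)
  }

  // Random feature index in [0, numFeatures).
  def randomFeature(numFeatures: Int, rng: Random): Int = rng.nextInt(numFeatures)

  // Splits the data in two lists: values below the split value go left, the rest go right.
  def splitData(data: Seq[Point], featureIndex: Int, splitValue: Double): (Seq[Point], Seq[Point]) =
    data.partition(_(featureIndex) < splitValue)
}
```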
A tree will not necessarily have a separate external node for each element, since we define a max tree depth and the data may contain duplicates at some point. In the picture below, you can see that we reached the max depth limit of the tree. We will not split the data in the red section into subtrees (which means that some nodes in the red section might contain more than one data point).
If the data in our node only consists of duplicates for the selected split dimension, then we will also not split up that node (it’s useless since all values are the same in that randomly chosen dimension).
So let’s build our tree …
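The recursive build can be sketched as follows, reusing the article's ITree ADT with the leaf `value` holding the records that reached it. This is a hedged reconstruction in the spirit of the paper's pseudocode, not the author's code: records are assumed to be `Vector[Double]`, values below the split go left (as in the paper), and the stop conditions mirror the ones described above (height limit, a single record, only duplicate values in the chosen dimension).

```scala
import scala.util.Random

object ITreeBuilder {
  type Point = Vector[Double]

  // Same ADT as in the article, repeated so that this snippet compiles on its own.
  sealed trait ITree[+A]
  case class ExternalNode[A](value: A, featureName: String, featureIndex: Int) extends ITree[A]
  case class InternalNode[A](splitValue: Double, featureName: String, featureIndex: Int,
                             left: ITree[A], right: ITree[A]) extends ITree[A]

  // Recursively builds one isolation tree; each leaf keeps the records that reached it.
  def build(data: Seq[Point], featureNames: Vector[String], height: Int, maxHeight: Int,
            rng: Random): ITree[Seq[Point]] = {
    val featureIndex = rng.nextInt(featureNames.size) // random dimension for this split
    lazy val column  = data.map(_(featureIndex))       // only evaluated when data.size > 1
    def leaf         = ExternalNode(data, featureNames(featureIndex), featureIndex)

    // Stop at the height limit, when at most one record is left, or when the chosen
    // dimension only contains duplicates (splitting on it would be useless).
    if (height >= maxHeight || data.size <= 1 || column.min == column.max) leaf
    else {
      val split         = column.min + rng.nextDouble() * (column.max - column.min)
      val (left, right) = data.partition(_(featureIndex) < split)
      InternalNode(split, featureNames(featureIndex), featureIndex,
        build(left, featureNames, height + 1, maxHeight, rng),
        build(right, featureNames, height + 1, maxHeight, rng))
    }
  }
}
```

Every record ends up in exactly one leaf, and no leaf lies deeper than `maxHeight`.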
We now have a way to build a tree from a subsample of our data, using a random dimension for each split. Our tree will not always have a single data record in each external node / leaf (because of duplicate values in a dimension or because of the tree depth limit).
Checking the depth of a record in our tree structure
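We now need a way to check the depth of a record, but this is a problem because our binary tree is not always complete (= not every external node contains exactly one element).
What do we do with those “incomplete” leaf nodes? The paper proposes a formula based on the fact that an isolation tree has the same structure as a binary search tree: the expected depth of a point inside a leaf with n elements is estimated by the average path length of an unsuccessful search in a binary search tree with n elements.
Every time we hit a leaf with n elements (= an incomplete leaf), we add the following to the depth calculated so far:
$$c(n) = 2H(n-1) - \frac{2(n-1)}{n}$$
where the harmonic number H(i) can be estimated as:
$$H(i) \approx \ln(i) + \gamma$$
In the paper, γ ≈ 0.5772156649 is Euler’s constant.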
This gives us the following functions:
- Path length : Our function that calculates the depth of a record in our tree.
- C : Our function that compensates if a leaf has multiple values
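A sketch of the two functions, using the formulas above. To keep it short, it uses a simplified tree type whose leaves only store their size (an assumption, not the article's ADT), and it treats n ≤ 1 as 0 and n = 2 as 1 (the exact value, since H(1) = 1). Names are illustrative.

```scala
object PathLength {
  private val EulerGamma = 0.5772156649

  // H(i) ≈ ln(i) + γ, as proposed in the paper.
  def harmonic(i: Double): Double = math.log(i) + EulerGamma

  // c(n): average path length of an unsuccessful search in a BST with n elements.
  def c(n: Int): Double =
    if (n <= 1) 0.0
    else if (n == 2) 1.0 // exact value, since H(1) = 1
    else 2.0 * harmonic(n - 1.0) - 2.0 * (n - 1) / n

  // Simplified tree for this sketch: leaves only store how many records they hold.
  sealed trait Tree
  final case class Leaf(size: Int) extends Tree
  final case class Node(featureIndex: Int, splitValue: Double, left: Tree, right: Tree) extends Tree

  // Depth of x, plus c(size) when the leaf still holds several records.
  @annotation.tailrec
  def pathLength(x: Vector[Double], tree: Tree, depth: Int = 0): Double = tree match {
    case Leaf(size) => depth + c(size)
    case Node(i, split, left, right) =>
      if (x(i) < split) pathLength(x, left, depth + 1)
      else pathLength(x, right, depth + 1)
  }
}
```

For a subsample of 256 records, c(256) is about 10.24.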
Building the forest
- We want to build a lot of trees based on a subsample of the data.
- We want to use random dimensions for each split in a tree
- We want to build small trees that are not necessarily completely split up.
So let’s build a forest…
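The forest-building step mostly consists of subsampling and choosing the height limit. In the sketch below, the tree builder is passed in as a function so that the snippet stands on its own; the height limit ceil(log2(ψ)) and subsampling without replacement follow the paper, while the names and the signature are hypothetical.

```scala
import scala.util.Random

object IsolationForestBuilder {
  type Point = Vector[Double]

  // Height limit from the paper: ceil(log2(psi)), computed with integer arithmetic
  // to avoid floating-point rounding (valid for psi >= 1).
  def heightLimit(psi: Int): Int = 32 - Integer.numberOfLeadingZeros(psi - 1)

  // Builds numTrees trees, each on its own random subsample of (at most) sampleSize records,
  // drawn without replacement. buildTree(sample, maxHeight) stands in for the tree builder.
  def buildForest[T](data: Seq[Point], numTrees: Int, sampleSize: Int, rng: Random)
                    (buildTree: (Seq[Point], Int) => T): Seq[T] = {
    require(data.nonEmpty && sampleSize >= 1, "need data and a positive sample size")
    val psi       = math.min(sampleSize, data.size)
    val maxHeight = heightLimit(psi)
    Seq.fill(numTrees)(buildTree(rng.shuffle(data).take(psi), maxHeight))
  }
}
```

The paper reports ψ = 256 and 100 trees as good default settings, which gives a height limit of 8.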
Calculating the anomaly score
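We have a forest of trees. We have a data point. We want to see if it’s an anomaly. The paper gives the following definition for the anomaly score:
$$s(x, n) = 2^{-\frac{E(h(x))}{c(n)}}$$
- s(x, n): anomaly score for point x given subsample size n.
- E(h(x)): average depth of x over all the trees in the forest.
- c(n): again our compensation function C (see above) with our subsample size as input.
When E(h(x)) equals c(n), the score is exactly 0.5; as E(h(x)) approaches 0, the score approaches 1. Scores close to 1 therefore indicate anomalies, while scores well below 0.5 indicate normal points.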
This gives us the following code, where we use 0.7 as our anomaly threshold (you can play around with this value, of course):
- Every data record with an anomaly result value ≥ 0.7 is an anomaly
- Every data record with an anomaly result value < 0.7 is a normal point
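A sketch of the scoring step, with the threshold of 0.7 from the text. It takes the path lengths of one point across all trees and repeats c(n) so that it compiles on its own; the object and method names are illustrative.

```scala
object AnomalyScore {
  private val EulerGamma = 0.5772156649

  // c(n) from the paper, with the usual edge cases for n <= 2.
  def c(n: Int): Double =
    if (n <= 1) 0.0
    else if (n == 2) 1.0
    else 2.0 * (math.log(n - 1.0) + EulerGamma) - 2.0 * (n - 1) / n

  // s(x, n) = 2^(-E(h(x)) / c(n)), where E(h(x)) is the mean path length over all trees.
  def score(pathLengths: Seq[Double], sampleSize: Int): Double = {
    val meanPath = pathLengths.sum / pathLengths.size
    math.pow(2.0, -meanPath / c(sampleSize))
  }

  // Threshold used in the article: scores >= 0.7 are treated as anomalies.
  def isAnomaly(s: Double, threshold: Double = 0.7): Boolean = s >= threshold
}
```

A point whose mean path length equals c(n) scores exactly 0.5; with a subsample size of 256, a mean path length of 3 gives a score of about 0.82, which is flagged.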
Reading a Kafka topic
We want an efficient, non-blocking Kafka consumer that can take in batches of data and feed that information to our anomaly detection algorithm.
We use the ZIO effect system library to handle this in a clean, functional, non-blocking programming style. Only 3 functions are tied together in our main program loop:
- Kafka processing grouped within loop function: this is our main loop function that subscribes to a Kafka topic. It groups the records in batches of a fixed size (and if collecting a batch takes too long, we just take what is present) and passes them to the process records function. After the processing, we commit the offsets.
- Process records function: the main function that is fed with data and decides what to do with an anomaly. In our code, we just print the result.
- Find json values function: we need to parse JSON quickly to retrieve the data that our algorithm needs.
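The glue between the batching loop and the algorithm can be sketched without any Kafka or ZIO types. This sketch assumes the JSON of each record has already been flattened to a `Map[String, String]`, and `scoreBatch` is a placeholder for "train a forest on this batch and score every record"; neither is the project's actual API.

```scala
import scala.util.Try

object ProcessRecords {
  // Assumption: each JSON record has already been flattened to label -> raw value.
  type Record = Map[String, String]

  // Extracts (id, feature vector) pairs; records with a missing or non-numeric field are skipped.
  def toFeatures(batch: Seq[Record], idField: String,
                 anomalyFields: Seq[String]): Seq[(String, Vector[Double])] =
    batch.flatMap { record =>
      val values = anomalyFields.map(f => record.get(f).flatMap(v => Try(v.trim.toDouble).toOption))
      for (id <- record.get(idField) if values.forall(_.isDefined))
        yield (id, values.flatten.toVector)
    }

  // Scores one batch and returns the (id, score) pairs at or above the threshold.
  // scoreBatch is a placeholder for "train a forest on this batch, then score each record".
  def process(batch: Seq[Record], idField: String, anomalyFields: Seq[String], threshold: Double)
             (scoreBatch: Seq[Vector[Double]] => Seq[Double]): Seq[(String, Double)] = {
    val features = toFeatures(batch, idField, anomalyFields)
    val scores   = scoreBatch(features.map(_._2))
    features.map(_._1).zip(scores).filter { case (_, score) => score >= threshold }
  }
}
```

Records with a missing or non-numeric anomaly field are skipped rather than scored; whether to skip, log or reject them is a design choice.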
The isolation forest algorithm is a useful lightweight (low linear time-complexity and small memory requirement) multidimensional point anomaly detection algorithm.
It’s an unsupervised algorithm that removes the need for “normal cases” labeling, which is a very powerful feature in my opinion.
The article contains the main code snippets. You can find the full code on GitHub here. We could have optimized the algorithm code with ZIO, but I wanted to keep ZIO separate from the algorithm code.
- Fei Tony Liu, Kai Ming Ting and Zhi-Hua Zhou, “Isolation Forest” (ICDM 2008): the source of the pseudo-code