
Xyclade/MachineLearning: Literature Study


Repository name: Xyclade/MachineLearning

Repository URL: https://github.com/Xyclade/MachineLearning

Language: Scala 100.0%

Introduction:

Most developers these days have heard of machine learning, but when trying to find an 'easy' way into this technique, most people find themselves scared off by the abstractness of the concept of machine learning and terms such as regression, unsupervised learning, Probability Density Function and many other definitions. If one turns to books, there are titles such as An Introduction to Statistical Learning with Applications in R and Machine Learning for Hackers, which use the programming language R for their examples.

However, R is not really a language in which one writes programs for everyday use, as is done with, for example, Java, C#, Scala, etc. This is why in this blog machine learning will be introduced using Smile, a machine learning library that can be used in both Java and Scala. These are languages that most developers have seen at least once during their study or career.

The first section, 'The global idea of machine learning', contains all the important concepts and notions you need to know about to get started with the practical examples that are described in the section 'Practical examples'. The section with practical examples is inspired by the examples from the book Machine Learning for Hackers. Additionally, the book Machine Learning in Action was used for validation purposes.

The second section, 'Practical examples', contains examples for various machine learning (ML) applications, with Smile as the ML library.

Note that in this blog, 'new' definitions are hyperlinked, so that if you want, you can read more about that specific topic, but you are not obliged to do so in order to work through the examples.

As a final note, I'd like to thank the following people:

  • Haifeng Li for his support and for writing the awesome, free-to-use library Smile.
  • Erik Meijer for all suggestions and supervision of the process of writing this blog.
  • Richard van Heest for his feedback and co-reading the blog.
  • Lars Willems for his feedback and co-reading the blog.

#The global idea of machine learning

You have probably heard about machine learning as a concept at one time or another. However, if you had to explain what machine learning is to another person, how would you do it? Think about this for a second before reading the rest of this section.

Machine learning is explained in many ways, some more accurate than others, and there is a lot of inconsistency in its definition. Some say machine learning is generating a static model based on historical data, which then allows you to predict for future data. Others say it's a dynamic model that keeps on changing as more data is added over time.

I agree more with the dynamic definition, but due to certain limitations we explain the static model method in the examples. However, we do explain how the dynamic principle would work in the subsection Dynamic machine learning.

The upcoming subsections explain commonly used definitions and notions in the machine learning field. We advise you to read through these before starting the practical examples.

##Features

A feature is a property on which a model is trained. Say, for example, that you classify emails as spam or ham based on the frequency of the words 'Buy' and 'Money'. Then these words are features, or part of a feature if you combine them with more words. If you were to use machine learning to predict whether someone is a friend of yours, the number of 'common' friends could be a feature. Note that in the field, features are sometimes also referred to as attributes.

##Model

When one talks about machine learning, the term model is often mentioned. The model is the result of any machine learning method and the algorithm used within this method. This model can be used to make predictions in supervised learning, or to retrieve clusterings in unsupervised learning. Chances are high that you will encounter the terms online and offline training of a model in the field. The idea behind online training is that you add training data to an already existing model, whereas with offline training you generate a new model from scratch. For performance reasons, online training is the preferable method. However, for some algorithms it is not possible.

##Learning methods

In the field of machine learning there are two leading ways of learning, namely supervised learning and unsupervised learning. A brief introduction is necessary when you want to use machine learning in your applications, as picking the right machine learning approach and algorithm is an important but sometimes also a somewhat tedious process.

###Supervised learning

In supervised learning you define explicitly what features you want to use, and what output you expect. An example is predicting gender based on height and weight, which is known as a classification problem. Additionally, you can also predict absolute values with regression. An example of regression with the same data would be predicting one's height based on gender and weight. Some supervised algorithms can only be used for either classification or regression, such as K-NN. However, there also exist algorithms such as Support Vector Machines which can be used for both purposes.

####Classification

The problem of classification within the domain of supervised learning is relatively simple. Given a set of labels, and some data that already received the correct labels, we want to be able to predict labels for new data that we did not label yet. However, before thinking of your data as a classification problem, you should look at what the data looks like. If there is a clear structure in the data such that you can easily draw a regression line, it might be better to use a regression algorithm instead. If the data does not fit a regression line, or when performance becomes an issue, classification is a good alternative.

An example of a classification problem would be to classify emails as ham or spam based on their content. Given a training set in which emails are labeled ham or spam, a classification algorithm can be used to train a model. This model can then be used to predict for future emails whether they are ham or spam. A typical example of a classification algorithm is the K-NN algorithm. Another commonly used example of a classification problem is classifying email as spam or ham, which is also one of the examples in this blog.

####Regression

Regression is a lot more powerful than classification. This is because in regression you are predicting actual values, rather than labels. Let us clarify this with a short example: given a table of weights, heights and genders, you can use K-NN to predict one's gender when given a weight and height. With this same dataset using regression, you could instead predict one's weight or height, given the gender and the respective other missing parameter.

With this extra power comes great responsibility, so in the field of regression one should be very careful when generating the model. Common pitfalls are overfitting, underfitting and not taking into account how the model handles extrapolation and interpolation.

###Unsupervised learning

In contrast to supervised learning, with unsupervised learning you do not know the output beforehand. The idea when applying unsupervised learning is to find a hidden underlying structure in a dataset. An example would be PCA, with which you can reduce the number of features by combining features. This combining is done based on the possibly hidden correlation between these features. Another example of unsupervised learning is K-means clustering. The idea behind K-means clustering is to find groups within a dataset, such that these groups can later on be used for purposes such as supervised learning.

####Principal Components Analysis (PCA)

Principal Components Analysis is a technique used in statistics to convert a set of correlated columns into a smaller set of uncorrelated columns, reducing the number of features of a problem. This smaller set of columns is called the principal components. The technique is mostly used in exploratory data analysis, as it reveals internal structure in the data that cannot be found by eyeballing the data.

A big weakness of PCA, however, is outliers in the data. These heavily influence its result, so looking at the data beforehand and eliminating large outliers can greatly improve its performance.

To give a clear idea of what PCA does, we show the plots of a dataset of points with 2 dimensions in comparison to the same dataset plotted after PCA is applied.

On the left plot the original data is shown, where each color represents a different class. It is clear that you could reduce from the 2 dimensions (X and Y) to 1 dimension and still classify properly. This is where PCA comes into play. With PCA a new value is calculated for each datapoint, based on its original dimensions.

In the plot on the right you see the result of applying PCA to this data. Note that there is a Y value, but this is purely to be able to plot the data and show it to you. This Y value is 0 for all points, as only the X values are returned by the PCA algorithm. Also note that the X values in the right plot do not correspond to the values in the left plot; this shows that PCA does not 'just drop' a dimension.
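As a minimal sketch of how this dimension reduction could look in code, here is an example assuming the smile.projection.PCA API of the Smile version used in this blog (the exact calls may differ per version, and the data below is made up for illustration):

val data = Array(
  Array(1.0, 2.1),
  Array(2.0, 3.9),
  Array(3.0, 6.2),
  Array(4.0, 8.1)
)

//Compute the principal components of the dataset
val pca = new PCA(data)

//Project onto the first principal component: 2 dimensions become 1
pca.setProjection(1)
val projected = pca.project(data)

projected.foreach(point => println(point.mkString(", ")))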

##Validation techniques

In this section we will explain some of the techniques available for model validation, and some terms that are commonly used in the machine learning field within the scope of validation techniques.

###Cross validation

The technique of cross validation is one of the most common techniques in the field of machine learning. Its essence is to ignore part of your dataset while training your model, and to then use the model to predict this ignored data. Comparing the predictions to the actual values then gives an indication of the performance of your model and the quality of your training data.

The most important part of this cross validation is the splitting of data. You should always use the complete dataset when performing this technique. In other words you should not randomly select X datapoints for training and then randomly select X datapoints for testing, because then some datapoints can be in both sets while others might not be used at all.

####(2-fold) Cross validation

In 2-fold cross validation you perform a split of the data into test and training sets for each fold (so 2 times), train a model using the training dataset, and then verify with the testing set. Doing so allows you to compute the error in the predictions for the test data 2 times. These error values should not differ significantly. If they do, either something is wrong with your data or with the features you selected for model prediction. Either way, you should look into the data more and find out what is happening for your specific case, as training a model based on the data might result in an overfitted model for erroneous data.

###Regularization

The basic idea of regularization is preventing overfitting of your model by simplifying it. Suppose your data follows a 3rd degree polynomial function, but the data has noise, and this would cause the model to be of a higher degree. The model would then perform poorly on new data, whereas it seems to be a good model at first. Regularization helps prevent this by simplifying the model with a certain value lambda. However, finding the right lambda for a model is hard when you have no idea as to when the model is overfitted or not. This is why cross validation is often used to find the best lambda for your model.
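As an illustration of how such a lambda typically enters the picture, one common form (regularized least squares, also known as ridge regression; this is one example, not the only option) adds a lambda-weighted penalty on the size of the model's coefficients to the objective that is minimised:

$$\min_{\beta} \sum_{i=1}^{n} \left(y_i - x_i^T\beta\right)^2 + \lambda \|\beta\|^2$$

The higher lambda is, the more the coefficients are pushed towards zero and the simpler the resulting model becomes; cross validation then helps pick the lambda with the best predictive performance.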

####Precision

In the field of computer science we use the term precision to describe the fraction of selected items that are actually relevant. So when you compute the precision value for a search algorithm on documents, the precision of that algorithm is defined by how many documents from the result set are actually relevant.

This value is computed as follows:
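$$\text{Precision} = \frac{|\{\text{relevant documents}\} \cap \{\text{retrieved documents}\}|}{|\{\text{retrieved documents}\}|}$$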

As this might be a bit hard to grasp I will give an example:

Say we have the documents {aa, ab, bc, bd, ee} as the complete corpus, and we query for documents with 'a' in the name. If our algorithm then returns the document set {aa, ab}, the precision intuitively would be 100%. Let's verify this by filling in the formula:
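$$\text{Precision} = \frac{|\{aa, ab\} \cap \{aa, ab\}|}{|\{aa, ab\}|} = \frac{2}{2} = 100\%$$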

Indeed it is 100%. If we run the query again, but get more results back than only {aa, ab}, say we additionally get {bc, bd} as well, this influences the precision as follows:
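$$\text{Precision} = \frac{|\{aa, ab\} \cap \{aa, ab, bc, bd\}|}{|\{aa, ab, bc, bd\}|} = \frac{2}{4} = 50\%$$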

Here the results contained the relevant results, but also 2 irrelevant results. This caused the precision to decrease. However, if you were to calculate the recall for this example, it would be 100%. This is how precision and recall differ from each other.

####Recall

Recall is used to describe the fraction of relevant items that are retrieved by an algorithm, given a query and a data corpus. So given a set of documents and a query that should return a subset of those documents, the recall value represents how many of the relevant documents are actually returned. This value is computed as follows:
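$$\text{Recall} = \frac{|\{\text{relevant documents}\} \cap \{\text{retrieved documents}\}|}{|\{\text{relevant documents}\}|}$$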

Given this formula, let's do an example to see how it works:

Say we have the documents {aa, ab, bc, bd, ee} as the complete corpus, and we query for documents with 'a' in the name. If our algorithm returns {aa, ab}, the recall intuitively would be 100%. Let's verify this by filling in the formula:
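$$\text{Recall} = \frac{|\{aa, ab\} \cap \{aa, ab\}|}{|\{aa, ab\}|} = \frac{2}{2} = 100\%$$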

Indeed it is 100%. Next we shall show what happens if not all relevant results are returned:
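Say the algorithm now returns only {aa}:

$$\text{Recall} = \frac{|\{aa, ab\} \cap \{aa\}|}{|\{aa, ab\}|} = \frac{1}{2} = 50\%$$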

Here the results only contained half of the relevant results. This caused the recall to decrease. However, if you were to compute the precision for this situation, it would result in 100%, because all returned results are relevant.

####Prior

The prior value that belongs to a classifier, given a datapoint, represents the likelihood that this datapoint belongs to this classifier. In practice this means that when you get a prediction for a datapoint, the prior value that comes with it represents how 'convinced' the model is regarding the classification given to that datapoint.

####Root Mean Squared Error (RMSE)

The Root Mean Squared Error (RMSE, or RMSD where D stands for deviation) is the square root of the mean of the squared differences between the actual and predicted values. As this might be hard to grasp, I'll explain it using an example. Suppose we have the following values:

| Predicted temperature | Actual temperature | Squared difference for model | Squared difference for average |
| --- | --- | --- | --- |
| 10 | 12 | 4 | 7.1111 |
| 20 | 17 | 9 | 5.4444 |
| 15 | 15 | 0 | 0.1111 |

The mean of these squared differences for the model is 4.33333, and its root is 2.081666. So on average, the model predicts the values with an error of 2.08. The lower this RMSE value is, the better the model is at predicting. This is why in the field, when selecting features, one computes the RMSE with and without a certain feature, in order to say something about how that feature affects the performance of the model. With this information one can then decide whether the additional computation time for this feature is worth it in comparison to the improvement it brings to the model.
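In formula form, with $y_i$ the actual values and $\hat{y}_i$ the predicted values, this computation reads:

$$\text{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(\hat{y}_i - y_i\right)^2} = \sqrt{\frac{4 + 9 + 0}{3}} \approx 2.081666$$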

Additionally, because the RMSE is an absolute value, it can be normalised in order to compare models. This results in the Normalised Root Mean Squared Error (NRMSE). For computing this, however, you need to know the minimum and maximum values that the system can contain. Let's suppose temperatures can range from a minimum of 5 to a maximum of 25 degrees; computing the NRMSE is then done as follows:
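$$\text{NRMSE} = \frac{\text{RMSE}}{y_{max} - y_{min}}$$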

When we fill in the actual values we get the following result:
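$$\text{NRMSE} = \frac{2.081666}{25 - 5} \approx 0.1041 = 10.41\%$$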

Now what is this 10.41 value? It is the error percentage the model has on average on its datapoints.

Finally we can use the RMSE to compute a value that is known in the field as R Squared. This value represents how well the model performs in comparison to ignoring the model and just taking the average for each value. For that you first need to calculate the RMSE for the average. The mean of the squared differences is 4.22222 (taking the mean of the values from the last column in the table), and its root is 2.054805. The first thing you should notice is that this value is lower than that of the model. This is not a good sign, because it means the model performs worse than just taking the mean. However, to demonstrate how to compute R Squared we will continue the computations.

We now have the RMSE for both the model and the mean, and then computing how well the model performs in comparison to the mean is done as follows:
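$$R^2 = 1 - \frac{\text{RMSE}_{\text{model}}}{\text{RMSE}_{\text{mean}}}$$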

This results in the following computation:
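$$R^2 = 1 - \frac{2.081666}{2.054805} \approx -0.013072 = -1.3072\%$$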

Now what does this -1.3072% represent? Basically it says that the model that predicted these values performs about 1.31 percent worse than returning the average each time a value is to be predicted. In other words, we would be better off using the average function as a predictor rather than the model in this specific case.

##Common pitfalls

This section describes some common pitfalls in applying machine learning techniques. The idea of this section is to make you aware of these pitfalls and help you prevent actually walking into one yourself.

###Overfitting

When fitting a function on the data, there is a possibility that the data contains noise (for example due to measurement errors). If you fit every point from the data exactly, you incorporate this noise into the model. This causes the model to predict really well on your training data, but relatively poorly on future data.

The left image below shows what this overfitting would look like if you were to plot your data and the fitted functions, whereas the right image represents a good fit of the regression line through the datapoints.

Overfitting can easily happen when applying regression, but can just as easily be introduced in Naive Bayes classifications. In regression it happens with rounding, bad measurements and noisy data. In Naive Bayes, however, it could be caused by the features that were picked. An example of this would be classifying spam or ham while keeping all stop words.

Overfitting can be detected by performing validation techniques, by looking into your data's statistical features, and by detecting and removing outliers.

###Underfitting

When you are turning your data into a model but are leaving (a lot of) statistical data behind, this is called underfitting. This can happen due to various reasons, such as using the wrong regression type on the data. If you have a non-linear structure in the data and you apply linear regression, this will result in an under-fitted model. The left image below represents an under-fitted regression line, whereas the right image shows a well-fitted regression line.

You can prevent underfitting by plotting the data to get insight into its underlying structure, and by using validation techniques such as cross validation.

###Curse of dimensionality

The curse of dimensionality is a collection of problems that can occur when the size of your dataset is smaller than the number of features (dimensions) you are trying to use to create your machine learning model. An example of such a curse is matrix rank deficiency. When using Ordinary Least Squares (OLS), the underlying algorithm solves a linear system in order to build up a model. However, if you have more columns than rows, coming up with a single solution for this system is not possible. If this is the case, the best solution is to get more datapoints or to reduce the feature set.
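To illustrate the rank deficiency problem: OLS finds its coefficients $\hat{\beta}$ by solving the normal equations

$$\hat{\beta} = \left(X^TX\right)^{-1}X^Ty$$

where $X$ is the $n \times p$ matrix with one row per datapoint and one column per feature. When $p > n$ (more columns than rows), $X^TX$ is singular, so this inverse does not exist and no unique solution can be found.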

If you want to know more about this curse of dimensionality, there is a study focused on this issue. In this study, researchers Haifeng Li, Keshu Zhang and Tao Jiang developed an algorithm that improves cancer classification with very few datapoints. They compared their algorithm with support vector machines and random forests.

##Dynamic machine learning

In almost all literature you can find about machine learning, a static model is generated and validated, and then used for predictions or recommendations. However, in practice this alone would not make a very good machine learning application. This is why in this section we will explain how to turn a static model into a dynamic model. Since the (most optimal) implementation depends on the algorithm you are using, we will explain the concept rather than giving a practical example. Because explaining it in text only would not be very clear, we first present the whole system in a diagram. We will then use this diagram to explain machine learning and how to make the system dynamic.

The basic idea of machine learning can be described by the following steps:

  1. Gather data
  2. Split the data into a testing and a training set
  3. Train a model (with the help of a machine learning algorithm)
  4. Validate the model with a validation method which takes the model and testing data
  5. Do predictions based on the model

In this process there are a few steps missing when it comes to actual applications in the field. These are, in my opinion, the most important steps to make a system actually learn.

The idea behind what we call dynamic machine learning is as follows: you take your predictions, combine them with user feedback, and feed them back into your system to improve your dataset and model. As we just said, we need user feedback, so how is this obtained? Let's take friend suggestions on Facebook as an example. The user is presented with 2 options: 'Add Friend' or 'Remove'. Based on the decision of the user, you have direct feedback regarding that prediction.

So say you have this user feedback; then you can apply machine learning over your machine learning application to learn from the feedback that is given. This might sound a bit strange, but we will try to explain it more elaborately. Before we do this, however, we need to make a disclaimer: our description of the Facebook friend suggestion system is 100% an assumption and in no way confirmed by Facebook itself. Their systems are a secret to the outside world as far as we know.

Say the system predicts based on the following features:

  1. Number of common friends
  2. Same hometown
  3. Same age

Then you can compute a prior for every person on Facebook regarding the chance that he/she is a good friend suggestion. Say you store the results of all these predictions for a period of time; analysing this data on its own with machine learning then allows you to improve your system. To elaborate on this: say most of our 'removed' suggestions had a high rating on feature 2, but a relatively low one on feature 1. Then we can add weights to the prediction system such that feature 1 is more important than feature 2. This will improve the recommendation system for us.

Additionally, the dataset grows over time, so we should keep on updating our model with the new data to make the predictions more accurate. How to do this however, depends on the size and mutation rate of your data.

#Practical examples

In this section we present a set of machine learning algorithms in a practical setting. The idea of these examples is to get you started with machine learning algorithms without an in-depth explanation of the underlying algorithms. We focus purely on the functional aspects of these algorithms, on how you can verify your implementation, and finally we try to make you aware of common pitfalls.

The following examples are available:

For each of these examples we used the Smile Machine Learning library. We used both the smile-core and smile-plot libraries. These libraries are available on Maven, Gradle, Ivy, SBT and Leiningen. Information on how to add them using one of these systems can be found here for the core, and here for the plotting library.

So before you start working through an example, I assume you have made a new project in your favourite IDE and added the smile-core and smile-plot libraries to your project. Additional libraries, when used, and how to get the example data are addressed per example.

##Labeling ISPs based on their down/upload speed (K-NN using Smile in Scala)

The goal of this section is to use the K-Nearest Neighbours (K-NN) algorithm to classify download/upload speed pairs as internet service provider (ISP) Alpha (represented by 0) or Beta (represented by 1). The idea behind K-NN is as follows: given a set of points that are classified, you can classify a new point by looking at its K neighbours (K being a positive integer). These K neighbours are found by taking the K points with the smallest Euclidean distance to the new point. You then look at the class most represented among these neighbours and assign that class to the new point.
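For reference, the Euclidean distance between two points $p$ and $q$ in $n$ dimensions is computed as:

$$d(p, q) = \sqrt{\sum_{i=1}^{n} \left(p_i - q_i\right)^2}$$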

To start this example you should download the example data. Additionally you should set the path in the code snippet to where you stored this example data.

The first step is to load the CSV data file. As this is no rocket science, I provide the code for this without further explanation:

import java.io.File

object KNNExample {
  def main(args: Array[String]): Unit = {
    val basePath = "/.../KNN_Example_1.csv"
    val testData = getDataFromCSV(new File(basePath))
  }

  def getDataFromCSV(file: File): (Array[Array[Double]], Array[Int]) = {
    val source = scala.io.Source.fromFile(file)
    val data = source
      .getLines()
      .drop(1)
      .map(x => getDataFromString(x))
      .toArray

    source.close()
    val dataPoints = data.map(x => x._1)
    val classifierArray = data.map(x => x._2)
    return (dataPoints, classifierArray)
  }
  
  def getDataFromString(dataString: String): (Array[Double], Int) = {

    //Split the comma separated value string into an array of strings
    val dataArray: Array[String] = dataString.split(',')

    //Extract the values from the strings
    val xCoordinate: Double = dataArray(0).toDouble
    val yCoordinate: Double = dataArray(1).toDouble
    val classifier: Int = dataArray(2).toInt

    //And return the result in a format that can later 
    //easily be used to feed to Smile
    return (Array(xCoordinate, yCoordinate), classifier)
  }
}

The first thing you might wonder now is why the data is formatted this way. Well, the separation between the datapoints and their label values is for easy splitting between testing and training data, and because the API expects the data this way for both executing the K-NN algorithm and plotting the data. Secondly, storing the datapoints as an Array[Array[Double]] is done to support datapoints in more than just 2 dimensions.

Given the data, the first thing to do next is to see what the data looks like. For this Smile provides a nice plotting library. In order to use it, however, the application should be changed into a Swing application. Additionally, the data should be fed to the plotting library to get a JPanel with the actual plot. After changing your code it should look like this:

import java.awt.{Color, Dimension}
import java.io.File
import scala.swing.{MainFrame, SimpleSwingApplication}
import smile.plot.ScatterPlot

object KNNExample extends SimpleSwingApplication {
  def top = new MainFrame {
    title = "KNN Example"
    val basePath = "/.../KNN_Example_1.csv"

    val testData = getDataFromCSV(new File(basePath))

    val plot = ScatterPlot.plot(testData._1,
                                testData._2,
                                '@',
                                Array(Color.red, Color.blue))
    peer.setContentPane(plot)
    size = new Dimension(400, 400)
  }
 ...

The idea behind plotting the data is to verify whether K-NN is a fitting Machine Learning algorithm for this specific set of data. In this case the data looks as follows:

In this plot you can see that the blue and red points seem to be mixed in the area (3 < x < 5) and (5 < y < 7.5). Since the groups are mixed the K-NN algorithm is a good choice, as fitting a linear decision boundary would cause a lot of false classifications in the mixed area.

Given this choice of the K-NN algorithm as a good fit for this problem, let's continue with the actual machine learning part. For this the GUI is ditched, since it does not really add any value. Recall from the section The global idea of machine learning that machine learning has 2 key parts: prediction and validation. First we will look at validation, as using a model without any validation is never a good idea. The main reason to validate the model here is to prevent overfitting. However, before we can even do validation, a correct K should be chosen.

The drawback is that there is no golden rule for finding the correct K. However, a good K that allows most datapoints to be classified correctly can be found by looking at the data. Additionally, the K should be picked carefully to prevent undecidability by the algorithm. Say, for example, K=2 and the problem has 2 labels; then, when a point is between both labels, which one should the algorithm pick? There is a rule of thumb that K should be the square root of the number of features (in other words, the number of dimensions). In our example this would be K=1, but this is not really a good idea, since it would lead to more false classifications around decision boundaries. Picking K=2 would result in the undecidability error regarding our two labels, thus picking K=3 seems like a good fit for now.

For this example we do 2-fold cross validation. In general, 2-fold cross validation is a rather weak method of model validation, as it splits the dataset in half and only validates twice, which still allows for overfitting. But since the dataset here is only 100 points, 10-fold (which is a stronger version) does not make sense, as then there would only be 10 datapoints used for testing, which would give a skewed error rate.

  //Note: this uses smile.classification.KNN and smile.validation.CrossValidation
  def main(args: Array[String]): Unit = {
    val basePath = "/.../KNN_Example_1.csv"
    val testData = getDataFromCSV(new File(basePath))

    //Define the number of validation rounds, in our case 2,
    //and initialise the cross validation
    val validationRounds = 2
    val cv = new CrossValidation(testData._2.length, validationRounds)

    val testDataWithIndices = (
      testData._1.zipWithIndex,
      testData._2.zipWithIndex
    )

    val trainingDPSets = cv.train
      .map(indexList => indexList
        .map(index => testDataWithIndices._1
          .collectFirst { case (dp, `index`) => dp }.get))

    val trainingClassifierSets = cv.train
      .map(indexList => indexList
        .map(index => testDataWithIndices._2
          .collectFirst { case (dp, `index`) => dp }.get))

    val testingDPSets = cv.test
      .map(indexList => indexList
        .map(index => testDataWithIndices._1
          .collectFirst { case (dp, `index`) => dp }.get))

    val testingClassifierSets = cv.test
      .map(indexList => indexList
        .map(index => testDataWithIndices._2
          .collectFirst { case (dp, `index`) => dp }.get))

    val validationRoundRecords = trainingDPSets
      .zipWithIndex
      .map(x => (
        x._1,
        trainingClassifierSets(x._2),
        testingDPSets(x._2),
        testingClassifierSets(x._2)
      ))

    validationRoundRecords.foreach { record =>

      //Train a K-NN model with K = 3 on the training data of this round
      val knn = KNN.learn(record._1, record._2, 3)

      //And for each test data point make a prediction with the model
      val predictions = record._3
        .map(x => knn.predict(x))
        .zipWithIndex

      //Finally evaluate the predictions as correct or incorrect
      //and count the amount of wrongly classified data points
      val error = predictions
        .map(x => if (x._1 != record._4(x._2)) 1 else 0)
        .sum

      //Use toDouble to prevent integer division from rounding the rate down to 0
      println("False prediction rate: " + error.toDouble / predictions.length * 100 + "%")
    }
  }

If you execute this code several times, you might notice that the false prediction rate fluctuates quite a bit. This is due to the random samples taken for training and testing. When this random sample is picked a bit unfortunately, the error rate becomes much higher, while a good random sample can make the error rate extremely low.

Unfortunately I cannot provide you with a golden rule for when your model was trained with the best possible training set. One would say that the model with the lowest error rate is always the best, but when you recall the term overfitting, picking this particular model might perform really badly on future data. This is why having a large enough and representative dataset is key to a good machine learning application. However, when you are aware of this issue, you can implement ways to keep updating your model based on new data and known correct classifications.

Let's recap what we've done so far. First you took care of getting the training and testing data. Next you generated and validated several models and picked the one that gave the best results. We now have one final step to do, which is making predictions using this model:

val knn = KNN.learn(record._1, record._2, 3)
val unknownDataPoint = Array(5.3, 4.3)
val result = knn.predict(unknownDataPoint)

if (result == 0) {
  println("Internet Service Provider Alpha")
} else if (result == 1) {
  println("Internet Service Provider Beta")
} else {
  println("Unexpected prediction")
}

The result of executing this code is that the unknownDataPoint (5.3, 4.3) is labeled as ISP Alpha. This is one of the easier points to classify, as it is clearly in the Alpha field of the datapoints in the plot. As it is now clear how to do these predictions, I won't present you with other points, but feel free to try out how different points get predicted.

##Classifying email as spam or ham (Naive Bayes)

In this example we will be using the Naive Bayes algorithm to classify emails as ham (good emails) or spam (bad emails) based on their content. The Naive Bayes algorithm calculates the probability of an object for each possible class, and then returns the class with the highest probability. For this probability calculation the algorithm uses features. The reason it's called Naive Bayes is that it does not incorporate any correlation between features. In other words, each feature counts on its own. I'll explain a bit more using an example:

Say you are classifying fruits and vegetables based on the features color, diameter and shape and you have the following classes: apple, tomato, and cranberry.

Suppose you then want to classify an object with the following feature values: (red, 4 cm, round). This would obviously be a tomato for us, as it is way too small to be an apple and too large for a cranberry. However, because the Naive Bayes algorithm evaluates each feature individually, it will classify it as follows:

  • Apple: 66.6% probable (based on color and shape)
  • Tomato: 100.0% probable (based on color, shape and size)
  • Cranberry: 66.6% probable (based on color and shape)

Thus even though it seems really obvious that it can't be a cranberry or an apple, Naive Bayes still gives it a 66.6% chance of being either one. So even though it classifies the tomato correctly, it can give poor results in edge cases where the size is just outside the scope of the training set. However, for spam classification Naive Bayes works well, as spam or ham cannot be classified purely based on one feature (word).

As you should now have an idea of how the Naive Bayes algorithm works, we can continue with the actual example. For this example we will use the Naive Bayes implementation from Smile in Scala to classify emails as spam or ham based on their content.

Before we can start, however, you should download the data for this example from the SpamAssassin public corpus. The data you need for the example is in the easy_ham and spam files, but the rest can also be used in case you feel like experimenting some more. You should unzip these files and adjust the file paths in the code snippets to match the location of the folders. Additionally you will need the stop words file for filtering purposes.

As with every machine learning implementation, the first step is to load the training data. However, in this example we are taking it one step further into machine learning. In the K-NN example we had the download and upload speeds as features. We did not refer to them as features, as they were the only properties available. For spam classification it is not completely trivial what to use as features. One can use the sender, the subject, the message content, and even the time of sending as features for classifying as spam or ham.

In this example we will use the content of the email as feature. By this we mean we will select the features (words in this case) from the bodies of the emails in the training set. In order to be able to do this, we need to build a Term Document Matrix (TDM).
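To give a feeling for what such a matrix contains, here is a hypothetical minimal sketch in plain Scala (not the implementation used in this example): for every document it counts how often each term occurs in the body.

//A hypothetical minimal Term Document Matrix: maps each document
//name to the occurrence count of every term in that document's body
def termDocumentMatrix(documents: Map[String, String]): Map[String, Map[String, Int]] =
  documents.map { case (name, body) =>
    val termCounts = body
      .toLowerCase
      .split("[^a-z0-9]+") //split the body into lowercased terms
      .filter(_.nonEmpty)
      .groupBy(identity)
      .map { case (term, occurrences) => (term, occurrences.length) }
    (name, termCounts)
  }

//Example usage on two made-up 'email bodies'
val tdm = termDocumentMatrix(Map(
  "mail1" -> "Buy cheap meds, buy now",
  "mail2" -> "Meeting at noon"
))
println(tdm("mail1")("buy")) //prints 2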

We will start off with writing the functions for loading the example data. This is done with a getMessage method, which gets a filtered body from an email, given a File as parameter.

def getMessage(file: File): String = {
    //A minimal sketch: get the message body, which is everything below
    //the blank line that separates the email headers from the content
    val source = scala.io.Source.fromFile(file)
    val fullMessage = source.getLines().mkString("\n")
    source.close()
    fullMessage.substring(fullMessage.indexOf("\n\n") + 2)
  }
