// Trainer.java
package edu.odu.cs.cs350.acmClassifier;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import org.apache.commons.lang3.ObjectUtils.Null;
import org.netlib.util.floatW;
import weka.classifiers.functions.SMO;
import weka.classifiers.functions.supportVector.RBFKernel;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
/**
 * A Support Vector Machine (SVM) wrapper that trains on and classifies documents.
 *
 * <p>Every document is represented as one Weka instance: one numeric attribute per
 * dictionary word (filled from the document's normalized word counts, in dictionary
 * order) followed by a single nominal class attribute with the values A-L. Letters A-K
 * represent an ACM classification; L represents "no classification".
 */
public class Trainer implements Cloneable {
    /**
     * Cache size passed to the RBF kernel when training.
     */
    private static final int KERNEL_CACHE_SIZE = 25007;

    /**
     * The internal learning machine, an SVM. Replaced by {@link #loadState(String)}.
     */
    private SMO svm;

    /**
     * The list of attributes that the learning machine uses: one per dictionary word,
     * with the class attribute as the last element.
     */
    private final ArrayList<Attribute> attrInfo;

    /**
     * The attribute that represents an ACM classification.
     */
    private final Attribute classAttr;

    /**
     * The gamma parameter for the learning machine's RBF kernel.
     */
    private final double gamma;

    /**
     * The C parameter for the learning machine.
     */
    private final double C;

    /**
     * Construct a Trainer from a given dictionary of words.
     *
     * @param dictionary the dictionary; each word becomes one numeric attribute
     */
    public Trainer(String[] dictionary) {
        svm = new SMO();
        attrInfo = new ArrayList<Attribute>();
        for (String word : dictionary) {
            attrInfo.add(new Attribute(word));
        }

        //Letters A-K represent an ACM classification. L represents no classification
        String[] classNames = {"A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L"};
        ArrayList<String> classList = new ArrayList<String>(Arrays.asList(classNames));

        //The classification attribute has a special name to not conflict with words in the dictionary
        classAttr = new Attribute("$ATTR_classification", classList);
        attrInfo.add(classAttr);

        //Initialize gamma and C to default values
        gamma = 0.01;
        C = 1.00;
    }

    /**
     * Copy constructor.
     *
     * <p>The new trainer receives its own deep copy of the SVM and its own attribute
     * list, so training or loading state on the copy does not affect {@code trainer}.
     * The {@link Attribute} objects themselves are shared between both trainers.
     *
     * @param trainer the trainer to copy from
     * @throws IllegalStateException if the SVM cannot be copied
     */
    public Trainer(Trainer trainer) {
        this.svm = copyOf(trainer.svm);
        this.attrInfo = new ArrayList<Attribute>(trainer.attrInfo);
        this.classAttr = trainer.classAttr;
        this.C = trainer.C;
        this.gamma = trainer.gamma;
    }

    /**
     * Create an independent copy of this trainer via the copy constructor.
     *
     * @return a new Trainer with the same configuration and a copy of the SVM
     */
    @Override
    public Trainer clone() {
        return new Trainer(this);
    }

    /**
     * Get a reference to the learning machine.
     *
     * @return the SVM
     */
    public SMO svm() {
        return svm;
    }

    /**
     * Classify a document using the learning machine.
     *
     * <p>A letter A-K is part of the result when its value in the SVM's distribution is
     * at least {@code 1 / numberOfClasses}. The null classification (L) is never combined
     * with other letters: an empty {@code ACMClass} is returned when no other class has a
     * strictly higher value than L. An empty {@code ACMClass} is also returned when the
     * SVM throws (the exception is printed to standard output).
     *
     * @param doc the document to classify
     * @return the document's classification
     */
    public ACMClass classify(Document doc) {
        Instances toBeClassified = new Instances("toBeClassified", attrInfo, 1);
        toBeClassified.setClass(classAttr);

        //The instance must have as many values as its dataset has attributes (exactly as
        //in train()). DenseInstance(int) starts with every value missing, so the class
        //value, which is what we are asking the SVM for, simply stays missing.
        int wordCount = attrInfo.size() - 1;
        Instance toClassify = new DenseInstance(attrInfo.size());
        toClassify.setDataset(toBeClassified);
        for (int i = 0; i < wordCount; i++) {
            toClassify.setValue(i, doc.normalizedWordCounts.get(i));
        }

        double[] dist;
        try {
            dist = svm.distributionForInstance(toClassify);
        } catch (Exception e) {
            System.out.println("Exception occurred in Trainer.classify()");
            System.out.println(e.toString());
            return new ACMClass();
        }

        //If a probability is equal to or above the threshold, it will be considered a valid classification.
        //The main exception is the null classification (L): because it doesn't make sense for the null
        //classification to be paired with other classifications, null is only output here if its probability
        //is greater than (or tied with) everything else.
        double threshold = 1.0 / classAttr.numValues();
        double nullProbability = dist[dist.length - 1];
        boolean isNullHighest = true;
        for (int i = 0; i < dist.length - 1; i++) {
            if (nullProbability < dist[i]) {
                isNullHighest = false;
                break;
            }
        }
        if (isNullHighest) {
            return new ACMClass();
        }

        //Null was not found to be the highest probability, so create the classification string
        StringBuilder classString = new StringBuilder();
        for (int i = 0; i < dist.length - 1; i++) {
            if (dist[i] >= threshold) {
                classString.append((char) ('A' + i));
            }
        }
        return new ACMClass(classString.toString());
    }

    /**
     * Train the learning machine with a set of documents.
     *
     * <p>A document with several classifications contributes one training instance per
     * classification letter; a document with no classification contributes a single
     * instance labelled L. Exceptions thrown while building the classifier are printed
     * to standard output and not rethrown.
     *
     * @param docs the set of documents to train on
     */
    public void train(Document[] docs) {
        //Calculate the size of the training set
        int dataSize = 0;
        for (Document d : docs) {
            dataSize += Math.max(1, d.classification.size());
        }

        Instances training = new Instances("trainingData", attrInfo, dataSize);
        training.setClass(classAttr);
        int instanceSize = attrInfo.size();

        //Construct the full training set
        for (Document d : docs) {
            int classCount = Math.max(1, d.classification.size());
            for (int i = 0; i < classCount; i++) {
                Instance instance = new DenseInstance(instanceSize);
                instance.setDataset(training);
                for (int j = 0; j < instanceSize - 1; j++) {
                    instance.setValue(j, d.normalizedWordCounts.get(j));
                }
                String letter;
                if (d.classification.size() == 0) {
                    letter = "L";
                } else {
                    letter = d.classification.getLetter(i);
                }
                //The class attribute is always the last attribute
                instance.setValue(instanceSize - 1, letter);
                training.add(instance);
            }
        }

        //Create the kernel
        svm.setC(C);
        try {
            svm.setKernel(new RBFKernel(training, KERNEL_CACHE_SIZE, gamma));
            svm.buildClassifier(training);
        } catch (Exception e) {
            System.out.println("Exception occurred in Trainer.train()");
            System.out.println(e.toString());
        }
    }

    /**
     * Load the learning machine's state from a file.
     *
     * <p>Loading is best-effort: if the file cannot be read or does not contain an
     * {@code SMO}, the failure is printed to standard output, the current SVM is kept,
     * and that current SVM is returned.
     *
     * <p>NOTE(review): this uses Java-native deserialization, which must only be applied
     * to files from a trusted source.
     *
     * @param filePath the file to read from
     * @return the SMO object that was loaded from the file path, or the unchanged
     *         current SVM if loading failed
     * @throws Exception declared for backward compatibility; read failures are reported
     *         on standard output instead of being thrown
     */
    public SMO loadState(String filePath) throws Exception {
        File f = new File(filePath);
        //try-with-resources closes both streams even when reading fails part-way
        try (FileInputStream fis = new FileInputStream(f);
                ObjectInputStream ois = new ObjectInputStream(fis)) {
            this.svm = (SMO) ois.readObject();
            System.out.println("Loading SVM from file: " + filePath);
            return this.svm;
        } catch (Exception e) {
            //Error case (e.g. not a valid file path): keep the current SVM, but say why
            System.out.println("SVM object is empty");
            System.out.println(e.toString());
        }
        return this.svm;
    }

    /**
     * Save the learning machine's state to a file.
     *
     * <p>The SVM is only written when {@code filePath} already exists; otherwise a
     * message is printed and nothing is written. Write failures are printed to standard
     * output and not rethrown.
     *
     * @param filePath the (existing) file to write to
     * @throws IOException declared for backward compatibility; write failures are
     *         reported on standard output instead of being thrown
     */
    public void saveState(String filePath) throws IOException {
        File f = new File(filePath);
        if (!f.exists()) {
            System.out.println(filePath + " does not exist");
            return;
        }
        //try-with-resources flushes and closes both streams; no manual close() needed
        try (FileOutputStream fos = new FileOutputStream(f);
                ObjectOutputStream oos = new ObjectOutputStream(fos)) {
            //write svm object into file
            oos.writeObject(this.svm);
            oos.flush();
            System.out.println("Saving SVM object to : " + filePath);
        } catch (IOException e) {
            System.out.println("Stream Error");
            System.out.println(e.toString());
        }
    }

    /**
     * Describe this trainer by its dictionary size and SVM parameters.
     *
     * @return a short human-readable description
     */
    @Override
    public String toString() {
        return "Trainer[words=" + (attrInfo.size() - 1) + ", C=" + C + ", gamma=" + gamma + "]";
    }

    /**
     * Deep-copy an SVM by serializing it to memory and reading it back.
     *
     * @param original the SVM to copy
     * @return an independent copy of {@code original}
     * @throws IllegalStateException if the SVM cannot be serialized or restored
     */
    private static SMO copyOf(SMO original) {
        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(buffer)) {
                out.writeObject(original);
            }
            try (ObjectInputStream in =
                    new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
                return (SMO) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Unable to copy the SVM", e);
        }
    }
}