# classification_analysis.py
import argparse

import pandas as pd
import seaborn as sns
from sklearn import model_selection
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier
  7. parser = argparse.ArgumentParser(description="Analyze iris data")
  8. parser.add_argument('data', help="Input data (CSV) to process")
  9. parser.add_argument('output_figure', help="Output figure path")
  10. parser.add_argument('output_report', help="Output report path")
  11. args = parser.parse_args()
  12. # prepare the data as a pandas dataframe
  13. df = pd.read_csv(args.data)
  14. attributes = ["sepal_length", "sepal_width", "petal_length","petal_width", "class"]
  15. df.columns = attributes
  16. # create a pairplot to plot pairwise relationships in the dataset
  17. plot = sns.pairplot(df, hue='class', palette='muted')
  18. plot.savefig(args.output_figure)
  19. # perform a K-nearest-neighbours classification with scikit-learn
  20. # Step 1: split data in test and training dataset (20:80)
  21. array = df.values
  22. X = array[:,0:4]
  23. Y = array[:,4]
  24. test_size = 0.20
  25. seed = 7
  26. X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y,
  27. test_size=test_size,
  28. random_state=seed)
  29. # Step 2: Fit the model and make predictions on the test dataset
  30. knn = KNeighborsClassifier()
  31. knn.fit(X_train, Y_train)
  32. predictions = knn.predict(X_test)
  33. # Step 3: Save the classification report
  34. report = classification_report(Y_test, predictions, output_dict=True)
  35. df_report = pd.DataFrame(report).transpose().to_csv(args.output_report)