DL-101-168-162 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. ### DVC
  2. $ cat << EOT >| src/train.py
  3. from joblib import dump
  4. from pathlib import Path
  5. import numpy as np
  6. import pandas as pd
  7. from skimage.io import imread_collection
  8. from skimage.transform import resize
  9. from sklearn.ensemble import RandomForestClassifier
  10. def load_images(data_frame, column_name):
  11. filelist = data_frame[column_name].to_list()
  12. image_list = imread_collection(filelist)
  13. return image_list
  14. def load_labels(data_frame, column_name):
  15. label_list = data_frame[column_name].to_list()
  16. return label_list
  17. def preprocess(image):
  18. resized = resize(image, (100, 100, 3))
  19. reshaped = resized.reshape((1, 30000))
  20. return reshaped
  21. def load_data(data_path):
  22. df = pd.read_csv(data_path)
  23. labels = load_labels(data_frame=df, column_name="label")
  24. raw_images = load_images(data_frame=df, column_name="filename")
  25. processed_images = [preprocess(image) for image in raw_images]
  26. data = np.concatenate(processed_images, axis=0)
  27. return data, labels
  28. def main(repo_path):
  29. train_csv_path = repo_path / "data/prepared/train.csv"
  30. train_data, labels = load_data(train_csv_path)
  31. rf = RandomForestClassifier()
  32. trained_model = rf.fit(train_data, labels)
  33. dump(trained_model, repo_path / "model/model.joblib")
  34. if __name__ == "__main__":
  35. repo_path = Path(__file__).parent.parent
  36. main(repo_path)
  37. EOT