1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
- ### DVC
- $ cat << EOT >| src/train.py
- from joblib import dump
- from pathlib import Path
- import numpy as np
- import pandas as pd
- from skimage.io import imread_collection
- from skimage.transform import resize
- from sklearn.ensemble import RandomForestClassifier
def load_images(data_frame, column_name):
    """Load the image files whose paths are listed in one data-frame column.

    Args:
        data_frame: Table holding the image file paths.
        column_name: Name of the column that contains the paths.

    Returns:
        Whatever ``imread_collection`` returns for the list of paths
        (an image collection, per scikit-image).
    """
    paths = data_frame[column_name].to_list()
    return imread_collection(paths)
def load_labels(data_frame, column_name):
    """Return the values of one data-frame column as a plain Python list.

    Args:
        data_frame: Table holding the labels.
        column_name: Name of the column to extract.

    Returns:
        A list with one entry per row, in row order.
    """
    column = data_frame[column_name]
    return column.to_list()
def preprocess(image):
    """Resize an image to 100x100x3 and flatten it into a single row.

    Args:
        image: Image array accepted by ``skimage.transform.resize``.

    Returns:
        An array of shape ``(1, 30000)`` holding the resized pixel values.
    """
    target_shape = (100, 100, 3)
    # 100 * 100 * 3 == 30000 values, laid out as one row per image.
    return resize(image, target_shape).reshape((1, 30000))
def load_data(data_path):
    """Read a prepared CSV and build the feature matrix plus label list.

    The CSV is expected to have a ``label`` column and a ``filename``
    column; each file named in ``filename`` is loaded and preprocessed.

    Args:
        data_path: Path of the CSV file to read.

    Returns:
        A ``(data, labels)`` tuple: ``data`` is the per-image rows stacked
        along axis 0, ``labels`` is the list of values from ``label``.
    """
    frame = pd.read_csv(data_path)
    labels = load_labels(data_frame=frame, column_name="label")
    images = load_images(data_frame=frame, column_name="filename")

    # Each preprocessed image is a single row; stack them into one matrix.
    rows = []
    for image in images:
        rows.append(preprocess(image))
    data = np.concatenate(rows, axis=0)
    return data, labels
def main(repo_path):
    """Train a random forest on the prepared training set and save it.

    Reads ``data/prepared/train.csv`` under ``repo_path``, fits a
    ``RandomForestClassifier`` with default settings, and writes the
    fitted model to ``model/model.joblib`` under ``repo_path``.

    Args:
        repo_path: Repository root as a ``pathlib.Path``.
    """
    train_csv_path = repo_path / "data/prepared/train.csv"
    train_data, labels = load_data(train_csv_path)

    rf = RandomForestClassifier()
    trained_model = rf.fit(train_data, labels)

    model_path = repo_path / "model/model.joblib"
    # dump() does not create missing directories; make sure ``model/``
    # exists so a fresh checkout does not fail after the training work.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    dump(trained_model, model_path)
if __name__ == "__main__":
    # Resolve first: when __file__ is a bare relative name (e.g. the script
    # is started from inside src/ on Python < 3.9), ``.parent.parent`` on
    # the unresolved path collapses to "." and points at src/ instead of
    # the repository root.
    repo_path = Path(__file__).resolve().parent.parent
    main(repo_path)
- EOT
|