diff --git a/Makefile b/Makefile index 56a15801..081e80a6 100644 --- a/Makefile +++ b/Makefile @@ -86,16 +86,13 @@ install-develop: clean-build clean-pyc ## install the package in editable mode a .PHONY: lint lint: ## check style with flake8 and isort - find pipelines -name '*.json' | xargs -n1 -I{} bash -c "diff -q {} <(python -m json.tool {})" - find mlprimitives/jsons -name '*.json' | xargs -n1 -I{} bash -c "diff -q {} <(python -m json.tool {})" flake8 mlprimitives tests isort -c --recursive mlprimitives tests + find pipelines -name '*.json' | xargs -n1 -I{} bash -c "diff -q {} <(python -m json.tool {})" + find mlprimitives/jsons -name '*.json' | xargs -n1 -I{} bash -c "diff -q {} <(python -m json.tool {})" .PHONY: fix-lint fix-lint: ## fix lint issues using autoflake, autopep8, and isort - find pipelines -name '*.json' | xargs -n1 -I{} bash -c "python -m json.tool {} {}.tmp && mv {}.tmp {}" - find mlprimitives/jsons -name '*.json' | xargs -n1 -I{} bash -c "python -m json.tool {} {}.tmp && mv {}.tmp {}" - find mlprimitives -name '*.py' | xargs autoflake --in-place --remove-all-unused-imports --remove-unused-variables autopep8 --in-place --recursive --aggressive mlprimitives isort --apply --atomic --recursive mlprimitives @@ -104,6 +101,9 @@ fix-lint: ## fix lint issues using autoflake, autopep8, and isort autopep8 --in-place --recursive --aggressive tests isort --apply --atomic --recursive tests + find pipelines -name '*.json' | xargs -n1 -I{} bash -c "python -m json.tool {} {}.tmp && mv {}.tmp {}" + find mlprimitives/jsons -name '*.json' | xargs -n1 -I{} bash -c "python -m json.tool {} {}.tmp && mv {}.tmp {}" + # TEST TARGETS diff --git a/mlprimitives/adapters/keras.py b/mlprimitives/adapters/keras.py index cb210538..1c926def 100644 --- a/mlprimitives/adapters/keras.py +++ b/mlprimitives/adapters/keras.py @@ -50,8 +50,9 @@ def _build_model(self, **kwargs): return model - def __init__(self, layers, loss, optimizer, classification, - metrics=None, epochs=10, 
verbose=False, **hyperparameters): + def __init__(self, layers, loss, optimizer, classification, callbacks=tuple(), + metrics=None, epochs=10, verbose=False, validation_split=0, batch_size=32, + shuffle=True, **hyperparameters): self.layers = list() for layer in layers: @@ -67,6 +68,14 @@ def __init__(self, layers, loss, optimizer, classification, self.verbose = verbose self.classification = classification self.hyperparameters = hyperparameters + self.validation_split = validation_split + self.batch_size = batch_size + self.shuffle = shuffle + + for callback in callbacks: + callback['class'] = import_object(callback['class']) + + self.callbacks = callbacks def fit(self, X, y, **kwargs): self.model = self._build_model(**kwargs) @@ -74,10 +83,17 @@ def fit(self, X, y, **kwargs): if self.classification: y = keras.utils.to_categorical(y) - self.model.fit(X, y, epochs=self.epochs, verbose=self.verbose) + callbacks = [ + callback['class'](**callback.get('args', dict())) + for callback in self.callbacks + ] + + self.model.fit(X, y, epochs=self.epochs, verbose=self.verbose, callbacks=callbacks, + validation_split=self.validation_split, batch_size=self.batch_size, + shuffle=self.shuffle) def predict(self, X): - y = self.model.predict(X) + y = self.model.predict(X, batch_size=self.batch_size, verbose=self.verbose) if self.classification: y = np.argmax(y, axis=1) diff --git a/mlprimitives/cli.py b/mlprimitives/cli.py index e85a884e..378516ea 100644 --- a/mlprimitives/cli.py +++ b/mlprimitives/cli.py @@ -17,7 +17,7 @@ def _logging_setup(verbosity=1): logger = logging.getLogger() - log_level = (3 - verbosity) * 10 + log_level = (4 - verbosity) * 10 fmt = '%(asctime)s - %(levelname)s - %(message)s' formatter = logging.Formatter(fmt) logger.setLevel(log_level) diff --git a/mlprimitives/jsons/keras.Sequential.LSTMTextClassifier.json b/mlprimitives/jsons/keras.Sequential.LSTMTextClassifier.json index 2a2ea60e..b6fca59c 100644 --- 
a/mlprimitives/jsons/keras.Sequential.LSTMTextClassifier.json +++ b/mlprimitives/jsons/keras.Sequential.LSTMTextClassifier.json @@ -59,6 +59,10 @@ "type": "bool", "default": true }, + "verbose": { + "type": "bool", + "default": false + }, "conv_activation": { "type": "str", "default": "relu" @@ -85,6 +89,18 @@ "type": "int", "default": 10 }, + "callbacks": { + "type": "list", + "default": [] + }, + "validation_split": { + "type": "float", + "default": 0.0 + }, + "batch_size": { + "type": "int", + "default": 32 + }, "layers": { "type": "list", "default": [ diff --git a/pipelines/keras.Sequential.LSTMTextClassifier.json b/pipelines/keras.Sequential.LSTMTextClassifier.json new file mode 100644 index 00000000..212e14f9 --- /dev/null +++ b/pipelines/keras.Sequential.LSTMTextClassifier.json @@ -0,0 +1,57 @@ +{ + "metadata": { + "name": "keras.Sequential.LSTMTextClassifier", + "data_type": "text", + "task_type": "classification" + }, + "validation": { + "dataset": "newsgroups" + }, + "primitives": [ + "mlprimitives.custom.counters.UniqueCounter", + "mlprimitives.custom.text.TextCleaner", + "mlprimitives.custom.counters.VocabularyCounter", + "keras.preprocessing.text.Tokenizer", + "keras.preprocessing.sequence.pad_sequences", + "keras.Sequential.LSTMTextClassifier" + ], + "input_names": { + "mlprimitives.custom.counters.UniqueCounter#1": { + "X": "y" + } + }, + "output_names": { + "mlprimitives.custom.counters.UniqueCounter#1": { + "counts": "classes" + }, + "mlprimitives.custom.counters.VocabularyCounter#1": { + "counts": "vocabulary_size" + } + }, + "init_params": { + "mlprimitives.custom.counters.VocabularyCounter#1": { + "add": 1 + }, + "mlprimitives.custom.text.TextCleaner#1": { + "language": "en" + }, + "keras.preprocessing.sequence.pad_sequences#1": { + "maxlen": 100 + }, + "keras.Sequential.LSTMTextClassifier#1": { + "input_length": 100, + "verbose": true, + "validation_split": 0.2, + "callbacks": [ + { + "class": "keras.callbacks.EarlyStopping", + "args": {
"monitor": "val_acc", + "patience": 1, + "min_delta": 0.01 + } + } + ] + } + } +}