| import argparse |
| import os |
| import importlib |
|
|
def check_plugins(loaded_plugins):
    """Print a bulleted summary of the plugins that were loaded.

    Args:
        loaded_plugins: iterable of plugin entries; each one is rendered
            with its default string formatting on its own ``- `` line.
    """
    header = "Loaded plugins:"
    print(header)
    # map() stays lazy, so each entry is formatted right before it is printed.
    for bullet in map("- {}".format, loaded_plugins):
        print(bullet)
|
|
|
|
def train_model(dataset_name, plugins):
    """Run the placeholder training flow, firing plugin hooks along the way.

    Hooks are optional; a plugin that lacks a given hook is skipped for that
    stage. Stages run in this order:

    1. ``modify_model(model)`` on each plugin, chaining the returned value.
    2. ``on_train_start()`` on each plugin.
    3. The "training" step itself (currently only a log line).
    4. ``on_train_end()`` on each plugin.

    Args:
        dataset_name: name of the dataset, used in the log output.
        plugins: iterable of plugin objects.

    Returns:
        The model after every ``modify_model`` hook has been applied, so the
        plugins' modifications are visible to the caller.
    """
    # Snapshot once: the plugins are walked three times below, which would
    # silently skip the later stages for a one-shot iterable (e.g. a generator).
    plugins = list(plugins)

    model = "FlowModel"

    for plugin in plugins:
        if hasattr(plugin, 'modify_model'):
            model = plugin.modify_model(model)

    for plugin in plugins:
        if hasattr(plugin, 'on_train_start'):
            plugin.on_train_start()

    print(f"Training started on dataset: {dataset_name}")

    for plugin in plugins:
        if hasattr(plugin, 'on_train_end'):
            plugin.on_train_end()

    print("Training finished.")
    return model
|
|
|
|
def load_plugins():
    """Discover and instantiate plugins from the ``./plugins`` directory.

    Every ``*.py`` file other than ``__init__.py`` is imported as
    ``plugins.<name>``. The class looked up in that module is the file name
    passed through ``str.title()`` with underscores removed
    (``my_plugin.py`` -> ``MyPlugin``); it is instantiated with no arguments.
    A plugin that fails to import or construct is reported and skipped.

    The directory is created (with a hint printed) when it does not exist.

    NOTE: the import goes through ``importlib.import_module``, so the
    directory that contains ``plugins`` must be importable via ``sys.path``.

    Returns:
        list: plugin instances, ordered by plugin file name.
    """
    plugins_dir = './plugins'
    plugins = []

    if not os.path.exists(plugins_dir):
        # exist_ok avoids a crash if something creates the directory between
        # the existence check and this call.
        os.makedirs(plugins_dir, exist_ok=True)
        print(f"Plugins directory created at {plugins_dir}. Add your plugins there!")

    # os.listdir() returns entries in arbitrary order; sort so that plugins
    # load (and their hooks later fire) in a deterministic order.
    for filename in sorted(os.listdir(plugins_dir)):
        if not filename.endswith('.py') or filename == '__init__.py':
            continue

        plugin_name = filename[:-3]
        try:
            plugin_module = importlib.import_module(f'plugins.{plugin_name}')
            plugin_class = getattr(plugin_module, plugin_name.title().replace('_', ''), None)
            if plugin_class is not None:
                plugins.append(plugin_class())
                print(f"Plugin {plugin_name} loaded.")
            else:
                print(f"No class found in plugin {plugin_name}.")
        except Exception as e:
            # Plugin code is third-party and may raise anything at import or
            # construction time; one broken plugin must not block the rest.
            print(f"Failed to load plugin {plugin_name}: {e}")

    return plugins
|
|
|
|
def predict_model(plugins):
    """Run the prediction flow, giving each plugin a chance to react.

    Prints a start banner, calls ``on_predict()`` on every plugin that
    defines it (others are skipped), then prints a finish banner.

    Args:
        plugins: iterable of plugin objects.
    """
    print("Prediction started.")
    for candidate in plugins:
        if not hasattr(candidate, 'on_predict'):
            continue  # this plugin does not take part in prediction
        candidate.on_predict()
    print("Prediction finished.")
|
|
|
|
def main():
    """Command-line entry point for the FlowModel CLI.

    Parses the single positional ``command`` argument (``train``,
    ``predict`` or ``check_plugins``), loads the plugins once, and
    dispatches to the matching function.
    """
    parser = argparse.ArgumentParser(description="FlowModel CLI")
    parser.add_argument('command', choices=['train', 'predict', 'check_plugins'], help="Command to run")

    args = parser.parse_args()

    # load_plugins() returns a single list of plugin instances. Unpacking it
    # into two names raised ValueError unless exactly two plugins happened to
    # load. Load once and share the list across all commands, so plugins are
    # not imported and instantiated a second time for `train`.
    plugins = load_plugins()

    if args.command == 'train':
        train_model("mnist", plugins)
    elif args.command == 'predict':
        predict_model(plugins)
    elif args.command == 'check_plugins':
        # Report class names rather than raw instances so the listing is
        # readable even when a plugin defines no custom __str__/__repr__.
        check_plugins([type(plugin).__name__ for plugin in plugins])


if __name__ == "__main__":
    main()
|
|