def save(obj: Any, file: PathStr, *args: Any, **kwargs: Any) -> File:
    r"""
    Save any file with supported extensions.

    The serialisation backend is chosen from the lowercased file extension
    (without the leading dot), matched against the module-level extension
    collections ``PYTORCH``, ``NUMPY``, ``PANDAS``, ``PARQUET``, ``CSV``,
    ``JSON``, ``YAML`` and ``PICKLE``.

    Args:
        obj: Object to save.
        file: Destination path.
        *args: Extra positional arguments forwarded to the backend writer.
        **kwargs: Extra keyword arguments forwarded to the backend writer.
            Note: they are ignored when ``obj`` is a ``FlatDict`` saved as
            JSON/YAML, because ``FlatDict.json``/``FlatDict.yaml`` are called
            with the path only.

    Returns:
        The ``file`` argument, unchanged.

    Raises:
        ImportError: If the backend required for the extension is not installed.
        NotImplementedError: If a non-DataFrame object is saved to CSV.
        ValueError: If the extension is not supported.
    """
    extension = os.path.splitext(file)[-1].lower()[1:]
    if extension in PYTORCH:
        if not TORCH_AVAILABLE:
            raise ImportError(f"Trying to save {obj} to {file!r} but torch is not installed.")
        torch.save(obj, file, *args, **kwargs)
    elif extension in NUMPY:
        if not NUMPY_AVAILABLE:
            raise ImportError(f"Trying to save {obj} to {file!r} but numpy is not installed.")
        numpy.save(file, obj, *args, **kwargs)
    elif extension in PANDAS:
        if not PANDAS_AVAILABLE:
            raise ImportError(f"Trying to save {obj} to {file!r} but pandas is not installed.")
        pandas.to_pickle(obj, file, *args, **kwargs)
    elif extension in PARQUET:
        # Check availability first so that a missing pandas does not crash on
        # ``pandas.DataFrame`` and non-DataFrame objects can still use pyarrow.
        if PANDAS_AVAILABLE and isinstance(obj, pandas.DataFrame):
            obj.to_parquet(file, *args, **kwargs)
        elif not PYARROW_AVAILABLE:
            raise ImportError(f"Trying to save {obj} to {file!r} but pyarrow is not installed.")
        else:
            pyarrow.parquet.write_table(obj, file, *args, **kwargs)
    elif extension in CSV:
        # Only pandas DataFrames can be written as CSV; guard the attribute
        # access so a missing pandas yields the intended error.
        if PANDAS_AVAILABLE and isinstance(obj, pandas.DataFrame):
            obj.to_csv(file, *args, **kwargs)
        else:
            raise NotImplementedError(f"Trying to save {obj} to {file!r} but is not supported")
    elif extension in JSON:
        if isinstance(obj, FlatDict):
            obj.json(file)
        else:
            with open(file, "w") as fp:
                json.dump(obj, fp, *args, **kwargs)  # type: ignore
    elif extension in YAML:
        if isinstance(obj, FlatDict):
            obj.yaml(file)
        else:
            with open(file, "w") as fp:
                yaml.dump(obj, fp, *args, **kwargs)  # type: ignore
    elif extension in PICKLE:
        with open(file, "wb") as fp:
            pickle.dump(obj, fp, *args, **kwargs)  # type: ignore
    else:
        raise ValueError(f"Trying to save {obj} to {file!r} with unsupported extension={extension!r}")
    return file