sparse_caption.utils package
Submodules
sparse_caption.utils.config module
Created on Thu Jun 22 14:48:09 2017
@author: jiahuei
- class sparse_caption.utils.config.Config(x: Optional[str] = None, **kwargs)
Bases: object
Configuration object.
- check_loaded_version(loaded_version)
- compat()
- deepcopy()
- dict()
- get(key, default_value)
- json(**kwargs)
- classmethod load_config_json(config_filepath, verbose=True)
- save_config(exist_ok=True)
Save this instance to a json file.
- Parameters
exist_ok (bool) – If set to True, allow overwrites.
- update(kv_mapping)
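The save/load round trip that save_config and load_config_json provide can be sketched with the standard json module. MiniConfig is a hypothetical stand-in written for illustration only. It takes an explicit path argument, whereas the documented save_config takes only exist_ok, and none of its internals are taken from the real Config class.

```python
import json
import os
import tempfile


class MiniConfig:
    """Hypothetical stand-in for Config: attributes mirror a flat JSON object."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def dict(self):
        # Shallow copy so callers cannot mutate the instance by accident
        return dict(vars(self))

    def get(self, key, default_value):
        return getattr(self, key, default_value)

    def update(self, kv_mapping):
        self.__dict__.update(kv_mapping)
        return self

    def save_config(self, path, exist_ok=True):
        # Refuse to overwrite an existing file unless exist_ok is True
        if os.path.exists(path) and not exist_ok:
            raise FileExistsError(path)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.dict(), f, indent=2, sort_keys=True)

    @classmethod
    def load_config_json(cls, path):
        with open(path, "r", encoding="utf-8") as f:
            return cls(**json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    cfg = MiniConfig(lr=0.1, beam_size=3)
    cfg.update({"beam_size": 5})
    cfg.save_config(path)
    restored = MiniConfig.load_config_json(path)
    print(restored.dict())  # {'beam_size': 5, 'lr': 0.1}
```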
sparse_caption.utils.file module
Utilities for file download and caching.
- sparse_caption.utils.file.dump_json(path, data, utf8=True, lf_newline=True, **json_kwargs)
- sparse_caption.utils.file.dumps_file(path, string, utf8=True, lf_newline=True)
- sparse_caption.utils.file.file_size(path, suffix='B')
- sparse_caption.utils.file.get_file(fname, origin, dest_dir, md5_hash=None, file_hash=None, hash_algorithm='auto', extract=False, archive_format='auto')
Downloads a file from a URL if it is not already in the cache.
The file at the URL origin is downloaded to dest_dir and given the filename fname. The final location of a file example.txt would therefore be dest_dir/example.txt.
Files in tar, tar.gz, tar.bz, and zip formats can also be extracted. Passing a hash will verify the file after download. The command line programs shasum and sha256sum can compute the hash.
- Parameters
fname – Name of the file. If an absolute path /path/to/file.txt is specified the file will be saved at that location.
origin – Original URL of the file.
dest_dir – Location to store the files.
md5_hash – Deprecated in favor of ‘file_hash’. md5 hash of the file for verification
file_hash – The expected hash string of the file after download. The sha256 and md5 hash algorithms are both supported.
hash_algorithm – Select the hash algorithm to verify the file. Options are ‘md5’, ‘sha256’, and ‘auto’. The default ‘auto’ detects the hash algorithm in use.
extract – If True, tries to extract the file as an archive, such as tar or zip.
archive_format – Archive format to try for extracting the file. Options are ‘auto’, ‘tar’, ‘zip’, and None. ‘tar’ includes tar, tar.gz, and tar.bz files. The default ‘auto’ is [‘tar’, ‘zip’]. None or an empty list will return no matches found.
- Returns
Path to the downloaded file
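The cache-then-download flow described above can be outlined with urllib.request and hashlib. get_file_sketch and _sha256 are hypothetical names; the sketch verifies only sha256 hashes and omits the md5, hash_algorithm and extraction options, so treat it as an outline of the logic rather than the module's implementation.

```python
import hashlib
import os
import tempfile
import urllib.request
from pathlib import Path


def _sha256(path, chunk_size=65535):
    """Hex sha256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def get_file_sketch(fname, origin, dest_dir, file_hash=None):
    """Fetch `origin` into dest_dir/fname unless a valid cached copy exists."""
    os.makedirs(dest_dir, exist_ok=True)
    # os.path.join drops dest_dir when fname is absolute, so absolute paths win
    fpath = os.path.join(dest_dir, fname)
    if os.path.exists(fpath) and (file_hash is None or _sha256(fpath) == file_hash):
        return fpath  # cache hit: skip the download
    urllib.request.urlretrieve(origin, fpath)
    if file_hash is not None and _sha256(fpath) != file_hash:
        os.remove(fpath)
        raise ValueError(f"Downloaded file {fpath} failed hash verification")
    return fpath


# Usage with a local file:// URL so no network access is needed
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp, "source.txt")
    src.write_text("hello")
    expected = hashlib.sha256(b"hello").hexdigest()
    cache_dir = os.path.join(tmp, "cache")
    out = get_file_sketch("example.txt", src.as_uri(), cache_dir, expected)
    content = Path(out).read_text()
    # Second call is a cache hit, so the bogus origin is never contacted
    cached_again = get_file_sketch("example.txt", "file:///does/not/exist", cache_dir, expected)
```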
- sparse_caption.utils.file.list_dir(path)
- sparse_caption.utils.file.list_files(path)
- sparse_caption.utils.file.load_pil_image_from_url(url)
- sparse_caption.utils.file.read_file(path)
- sparse_caption.utils.file.read_json(path, utf8=True)
- sparse_caption.utils.file.tqdm_hook(t)
Wraps a tqdm instance. Don’t forget to close() or __exit__() the tqdm instance once you’re done with it (easiest using with syntax).
- Example
with tqdm(…) as t:
    reporthook = tqdm_hook(t)
    urllib.request.urlretrieve(…, reporthook=reporthook)
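The closure below follows the urlretrieve-hook pattern shown in tqdm's own documentation; tqdm_hook_sketch is a hypothetical name, and we assume, without having checked the source, that tqdm_hook works the same way. FakeBar stands in for a real tqdm bar so the example runs without tqdm installed; with tqdm you would pass the bar created in the with block instead.

```python
def tqdm_hook_sketch(t):
    """Adapt a tqdm-like bar `t` to urlretrieve's reporthook(block_num, block_size, total_size)."""
    last_block = 0

    def update_to(block_num=1, block_size=1, total_size=None):
        nonlocal last_block
        # urlretrieve passes total_size=-1 when the server sends no Content-Length
        if total_size not in (None, -1):
            t.total = total_size
        # Advance by the bytes received since the previous call
        t.update((block_num - last_block) * block_size)
        last_block = block_num

    return update_to


class FakeBar:
    """Minimal stand-in for tqdm exposing only .total, .n and .update()."""

    def __init__(self):
        self.total = None
        self.n = 0

    def update(self, k):
        self.n += k


bar = FakeBar()
hook = tqdm_hook_sketch(bar)
# urlretrieve calls the hook once with block_num=0 before reading any data
for block_num in range(4):
    hook(block_num, 1024, 3072)
print(bar.total, bar.n)  # 3072 3072
```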
- sparse_caption.utils.file.validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535)
Validates a file against a sha256 or md5 hash.
- Parameters
fpath – path to the file being validated
file_hash – The expected hash string of the file. The sha256 and md5 hash algorithms are both supported.
algorithm – Hash algorithm, one of ‘auto’, ‘sha256’, or ‘md5’. The default ‘auto’ detects the hash algorithm in use.
chunk_size – Bytes to read at a time, important for large files.
- Returns
Whether the file is valid
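Chunked hashing with the ‘auto’ detection rule can be sketched with hashlib. The rule used here (a 64-character hash means sha256, anything else md5) is an assumption borrowed from Keras's get_file, which these docstrings closely resemble; hash_file_sketch and validate_file_sketch are hypothetical names.

```python
import hashlib
import os
import tempfile


def hash_file_sketch(fpath, algorithm="sha256", chunk_size=65535):
    hasher = hashlib.sha256() if algorithm == "sha256" else hashlib.md5()
    with open(fpath, "rb") as f:
        # Read in chunks so large files never need to fit in memory
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def validate_file_sketch(fpath, file_hash, algorithm="auto", chunk_size=65535):
    # 'auto': a 64-hex-digit hash is taken as sha256, anything else as md5
    if algorithm == "sha256" or (algorithm == "auto" and len(file_hash) == 64):
        hasher = "sha256"
    else:
        hasher = "md5"
    return hash_file_sketch(fpath, hasher, chunk_size) == str(file_hash)


with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b"caption data")
    path = f.name
ok_sha = validate_file_sketch(path, hashlib.sha256(b"caption data").hexdigest())
ok_md5 = validate_file_sketch(path, hashlib.md5(b"caption data").hexdigest())
bad = validate_file_sketch(path, "0" * 64)
os.remove(path)
print(ok_sha, ok_md5, bad)  # True True False
```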
- sparse_caption.utils.file.zip_dir(target_dir, save_dir)
sparse_caption.utils.losses module
Created on 16 Sep 2020 15:00:29 @author: jiahuei
- class sparse_caption.utils.losses.LabelSmoothing(size=0, padding_idx=0, smoothing=0.0)
Bases: torch.nn.modules.module.Module
Implement label smoothing.
- forward(input, target, mask)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
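A common form of label smoothing keeps 1 - smoothing of the probability mass on the target token, spreads smoothing evenly over the other size - 1 tokens, and compares that target with the model's log-probabilities via KL divergence, averaged over unmasked tokens. The pure-Python sketch below illustrates that arithmetic; it is an assumption about what LabelSmoothing.forward computes, it ignores padding_idx, and all helper names are hypothetical.

```python
import math


def smoothed_target(vocab_size, target, smoothing):
    """Target distribution: 1 - smoothing on the true token, the rest spread evenly."""
    off_value = smoothing / (vocab_size - 1)
    dist = [off_value] * vocab_size
    dist[target] = 1.0 - smoothing
    return dist


def kl_div(log_probs, target_dist):
    """KL(target || model) from model log-probabilities; zero-probability terms add 0."""
    return sum(p * (math.log(p) - lp) for p, lp in zip(target_dist, log_probs) if p > 0)


def label_smoothing_loss(log_probs_seq, targets, mask, smoothing=0.1):
    vocab_size = len(log_probs_seq[0])
    losses = [
        kl_div(lp, smoothed_target(vocab_size, t, smoothing)) * m
        for lp, t, m in zip(log_probs_seq, targets, mask)
    ]
    # Average over the unmasked (real) tokens only
    return sum(losses) / sum(mask)


probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [1 / 3, 1 / 3, 1 / 3]]
log_probs = [[math.log(p) for p in row] for row in probs]
print(smoothed_target(3, 1, 0.1))  # [0.05, 0.9, 0.05]
loss = label_smoothing_loss(log_probs, targets=[0, 1, 2], mask=[1, 1, 0])
```

With smoothing=0.0 the target becomes one-hot and the loss reduces to the masked negative log-likelihood of the target tokens.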
- class sparse_caption.utils.losses.LanguageModelCriterion
Bases: torch.nn.modules.module.Module
- forward(input, target, mask)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class sparse_caption.utils.losses.RewardCriterion
Bases: torch.nn.modules.module.Module
- forward(input, mask, reward)
- This function computes
log(y_t) * reward * mask_t (where mask_t zeroes out non-words in the sequence)
- given
input = predicted probability
sequence = predicted word index
reward = …
- training: bool
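The docstring leaves the sign and normalisation implicit. In self-critical sequence training code this criterion is usually the negated, mask-weighted sum of log(y_t) * reward divided by the number of unmasked tokens, so that minimising it raises the probability of captions with positive reward. The pure-Python sketch below assumes exactly that; reward_criterion_sketch is a hypothetical name and its inputs are flattened per-token lists rather than tensors.

```python
def reward_criterion_sketch(log_probs, mask, reward):
    """-sum_t(log y_t * reward_t * mask_t) / sum_t(mask_t) over flattened tokens."""
    total = sum(-lp * r * m for lp, r, m in zip(log_probs, reward, mask))
    return total / sum(mask)


# Two real tokens and one padding token. The per-token reward is assumed to be
# a caption-level score such as (sample score - baseline score), broadcast to its tokens.
log_probs = [-0.5, -1.0, -2.0]
mask = [1, 1, 0]
print(reward_criterion_sketch(log_probs, mask, reward=[2.0, 2.0, 2.0]))  # 1.5
print(reward_criterion_sketch(log_probs, mask, reward=[-2.0, -2.0, -2.0]))  # -1.5
```

A negative reward flips the sign of the loss, so gradient descent then lowers the probability of that caption.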
sparse_caption.utils.misc module
Created on 29 Apr 2020 15:55:21 @author: jiahuei
Utility functions.
- class sparse_caption.utils.misc.ChoiceList(choices)
Bases: object
A type for argparse that validates choices. https://mail.python.org/pipermail/tutor/2011-April/082825.html
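Passed as an argparse type=, such an object is called on the raw command-line string and either returns the parsed value or raises an error. The sketch below assumes, from our reading of the linked recipe, that the value is a comma-separated list whose items must all come from choices; ChoiceListSketch is a hypothetical name and the real class may differ.

```python
import argparse


class ChoiceListSketch:
    """argparse type: accept a comma-separated list whose items are all in `choices`."""

    def __init__(self, choices):
        self.choices = list(choices)

    def __call__(self, csv):
        items = csv.split(",")
        invalid = sorted(set(items) - set(self.choices))
        if invalid:
            # argparse turns ArgumentTypeError into a usage error with this message
            raise argparse.ArgumentTypeError(
                f"invalid choices: {invalid} (choose from {self.choices})"
            )
        return items


parser = argparse.ArgumentParser()
parser.add_argument("--metrics", type=ChoiceListSketch(["Bleu_4", "CIDEr", "SPICE"]))
args = parser.parse_args(["--metrics", "CIDEr,Bleu_4"])
print(args.metrics)  # ['CIDEr', 'Bleu_4']
```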
- sparse_caption.utils.misc.configure_logging(logging_level: Union[int, str] = 20, logging_fmt: str = '%(levelname)s: %(name)s: %(funcName)s: %(message)s', logger_obj: Union[None, logging.Logger] = None) → logging.Logger
Set up logging on the root logger, because transformers calls logger.info upon import. Configures a simple console logger with the given level; a use case is to change the formatting of the default handler of the root logger.
- Adapted from:
https://stackoverflow.com/a/54366471/5825811
- Format variables:
https://docs.python.org/3/library/logging.html#logrecord-attributes
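Pointing a logger (the root logger by default) at a level and re-formatting its handlers can be sketched with the standard logging module. configure_logging_sketch is a hypothetical re-implementation, not the module's code; note that the documented default logging_level of 20 is logging.INFO.

```python
import logging
import sys
from typing import Optional, Union


def configure_logging_sketch(
    logging_level: Union[int, str] = logging.INFO,
    logging_fmt: str = "%(levelname)s: %(name)s: %(funcName)s: %(message)s",
    logger_obj: Optional[logging.Logger] = None,
) -> logging.Logger:
    # Default to the root logger, as the docstring describes
    logger = logger_obj if logger_obj is not None else logging.getLogger()
    logger.setLevel(logging_level)  # accepts ints or names such as "DEBUG"
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler(sys.stderr))
    # Re-format every handler, including any a library installed at import time
    formatter = logging.Formatter(logging_fmt)
    for handler in logger.handlers:
        handler.setFormatter(formatter)
    return logger


demo = configure_logging_sketch("DEBUG", logger_obj=logging.getLogger("demo"))
demo.debug("visible at DEBUG level")
```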
- sparse_caption.utils.misc.csv_to_float_list(input_string: str)
- sparse_caption.utils.misc.csv_to_int_list(input_string: str)
- sparse_caption.utils.misc.csv_to_str_list(input_string: str)
- sparse_caption.utils.misc.get_memory_info()
Get the node's total memory and memory usage. https://stackoverflow.com/a/17718729
- sparse_caption.utils.misc.grouper(iterable, group_n, fill_value=0)
- sparse_caption.utils.misc.humanise_number(size: Union[float, int], suffix: str = 'B')
- sparse_caption.utils.misc.replace_from_right(string: str, old: str, new: str, count: int = -1)
String replacement from right to left. https://stackoverflow.com/a/3679215
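The linked answer performs right-to-left replacement with str.rsplit followed by str.join. The sketch below assumes replace_from_right uses that approach; replace_from_right_sketch is a hypothetical name.

```python
def replace_from_right_sketch(string: str, old: str, new: str, count: int = -1) -> str:
    # rsplit with maxsplit=count splits at the rightmost `count` occurrences;
    # -1 splits at every occurrence, i.e. replaces all of them
    return new.join(string.rsplit(old, count))


print(replace_from_right_sketch("model.ckpt.bak.bak", ".bak", "", 1))  # model.ckpt.bak
print(replace_from_right_sketch("a-b-c", "-", "+"))  # a+b+c
```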
- sparse_caption.utils.misc.str_to_bool(input_string: str)
- sparse_caption.utils.misc.str_to_none(input_string: str)
- sparse_caption.utils.misc.str_to_sequence(input_string: str)
- sparse_caption.utils.misc.time_function(func)
sparse_caption.utils.model_utils module
@author: jiahuei, ruotianluo
- sparse_caption.utils.model_utils.clones(module, N)
Produce N identical layers.
- sparse_caption.utils.model_utils.count_nonzero(tensor)
- sparse_caption.utils.model_utils.densify_state_dict(state_dict)
- sparse_caption.utils.model_utils.length_average(length, logprobs, alpha=0.0)
Returns the average probability of tokens in a sequence.
- sparse_caption.utils.model_utils.length_wu(length, logprobs, alpha=0.0)
NMT length re-ranking score from “Google’s Neural Machine Translation System” (Wu et al., 2016).
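GNMT rescales a hypothesis's summed log-probability by the length penalty lp(Y) = (5 + |Y|)^α / (5 + 1)^α, so α = 0 leaves scores unchanged and larger α favours longer outputs. The sketch below assumes length_wu applies that division and that length_average divides by the raw length (an average log-probability); the _sketch names are hypothetical.

```python
def length_average_sketch(length, logprobs, alpha=0.0):
    # alpha is unused; kept only to mirror the documented signature
    return logprobs / length


def length_wu_sketch(length, logprobs, alpha=0.0):
    # GNMT length penalty lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha
    modifier = ((5 + length) ** alpha) / ((5 + 1) ** alpha)
    return logprobs / modifier


# (length, summed log-probability) for a short and a long caption.
# Raw scores favour the short one (-2.0 > -4.0); alpha=1.0 flips the ranking.
short, long_ = (2, -2.0), (10, -4.0)
print(length_wu_sketch(*long_, alpha=1.0))  # -1.6
print(length_wu_sketch(*short, alpha=1.0) < length_wu_sketch(*long_, alpha=1.0))  # True
```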
- sparse_caption.utils.model_utils.map_recursive(x: Any, func: Callable)
Applies func to elements of x recursively.
- Parameters
x – An item or a potentially nested structure of tuple, list or dict.
func – A single-argument function.
- Returns
The same x but with func applied.
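A recursive map over dicts, lists and tuples can be sketched as below. map_recursive_sketch is a hypothetical name; we assume map_to_cuda builds on the same idea with a function such as lambda t: t.cuda(), but the source does not confirm this.

```python
from typing import Any, Callable


def map_recursive_sketch(x: Any, func: Callable) -> Any:
    """Apply func to every leaf of a nested dict/list/tuple structure."""
    if isinstance(x, dict):
        return {k: map_recursive_sketch(v, func) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        mapped = [map_recursive_sketch(v, func) for v in x]
        # Keep tuples as tuples; lists stay lists
        return tuple(mapped) if isinstance(x, tuple) else mapped
    return func(x)


batch = {"ids": [1, (2, 3)], "length": 4}
result = map_recursive_sketch(batch, lambda v: v * 10)
print(result)  # {'ids': [10, (20, 30)], 'length': 40}
```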
- sparse_caption.utils.model_utils.map_to_cuda(x: Any)
- sparse_caption.utils.model_utils.pack_wrapper(module, att_feats, att_masks)
- sparse_caption.utils.model_utils.pad_unsort_packed_sequence(inputs, inv_ix)
- sparse_caption.utils.model_utils.penalty_builder(penalty_config)
- sparse_caption.utils.model_utils.reorder_beam(tensor: torch.Tensor, beam_idx: torch.Tensor, beam_dim: int = 0)
- sparse_caption.utils.model_utils.repeat_tensors(n, x, dim=0)
For a tensor of size B x …, repeat it n times to make it Bn x …. For collections, apply the repeat recursively to each element.
- sparse_caption.utils.model_utils.requires_grad(model, flag=True)
- sparse_caption.utils.model_utils.sequence_from_numpy(sequence)
- sparse_caption.utils.model_utils.set_seed(seed: int)
- sparse_caption.utils.model_utils.sort_pack_padded_sequence(inputs, lengths)
- sparse_caption.utils.model_utils.split_tensors(n, x)
- sparse_caption.utils.model_utils.to_cuda(x: Any)
sparse_caption.utils.natural_sort module
Created on Thu Jun 28 15:36:44 2018
@author: jiahuei
- sparse_caption.utils.natural_sort.atoi(text)
- sparse_caption.utils.natural_sort.natural_keys(text)
alist.sort(key=natural_keys) sorts in human order http://nedbatchelder.com/blog/200712/human_sorting.html (See Toothy’s implementation in the comments)
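The linked recipe splits a string into digit and non-digit runs and converts the digit runs to integers, so that "10" sorts after "2". The version below is the recipe as it is commonly written; we assume the module's atoi and natural_keys match it.

```python
import re


def atoi(text):
    return int(text) if text.isdigit() else text


def natural_keys(text):
    # Split into digit and non-digit runs: "img10.jpg" -> ["img", 10, ".jpg"]
    return [atoi(c) for c in re.split(r"(\d+)", text)]


files = ["ckpt_10.pth", "ckpt_2.pth", "ckpt_1.pth"]
files.sort(key=natural_keys)
print(files)  # ['ckpt_1.pth', 'ckpt_2.pth', 'ckpt_10.pth']
```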
sparse_caption.utils.optim module
Created on 16 Sep 2020 15:01:35 @author: jiahuei
- class sparse_caption.utils.optim.CosineOpt(optimizer, max_train_step, learning_rate_init, learning_rate_min)
Bases: sparse_caption.utils.optim.RateOpt
Optim wrapper that implements rate.
- rate()
Compute the current learning rate.
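CosineOpt's arguments suggest cosine annealing from learning_rate_init down to learning_rate_min over max_train_step steps. The sketch below computes that schedule under this assumption (no warm restarts, rate held at the minimum afterwards); cosine_rate is a hypothetical helper, not the class's rate() method.

```python
import math


def cosine_rate(step, max_train_step, learning_rate_init, learning_rate_min):
    """Cosine-annealed learning rate from learning_rate_init down to learning_rate_min."""
    # Clamp so the rate holds at the minimum once max_train_step is passed
    progress = min(step, max_train_step) / max_train_step
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return learning_rate_min + (learning_rate_init - learning_rate_min) * cosine


for step in (0, 500, 1000, 1500):
    print(step, round(cosine_rate(step, 1000, 0.01, 0.0001), 6))
```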
- class sparse_caption.utils.optim.NoamOpt(optimizer, model_size, factor, warmup)
Bases: sparse_caption.utils.optim.RateOpt
Optim wrapper that implements rate.
- rate()
Compute the current learning rate.
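The NoamOpt name and its (model_size, factor, warmup) arguments match the Transformer learning-rate schedule from “Attention Is All You Need” as popularised by the Annotated Transformer: lrate = factor · model_size^-0.5 · min(step^-0.5, step · warmup^-1.5). The sketch below computes that formula; we assume, without the source to confirm it, that NoamOpt.rate() does the same. noam_rate is a hypothetical helper.

```python
def noam_rate(step, model_size, factor, warmup):
    """Transformer ("Noam") schedule: linear warm-up, then inverse-square-root decay."""
    step = max(step, 1)  # step 0 would make step ** -0.5 raise ZeroDivisionError
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


for step in (1, 1000, 4000, 16000):
    print(step, noam_rate(step, model_size=512, factor=1.0, warmup=4000))
```

The two terms inside min() are equal at step == warmup, which is where the rate peaks; quadrupling the step after that halves the rate.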
- class sparse_caption.utils.optim.RateOpt
Bases: object
Optim wrapper that implements rate.
- step(step=None, epoch=None)
Update parameters and rate
- class sparse_caption.utils.optim.StepLROpt(optimizer, learning_rate_init, learning_rate_decay_start, learning_rate_decay_every, learning_rate_decay_rate)
Bases: sparse_caption.utils.optim.RateOpt
Optim wrapper that implements rate.
- rate()
Compute the current learning rate.
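StepLROpt's arguments suggest a step decay: hold learning_rate_init until learning_rate_decay_start, then multiply by learning_rate_decay_rate once every learning_rate_decay_every epochs. The sketch below assumes that reading; treating a negative decay start as "never decay" is also an assumption, and step_lr_rate is a hypothetical helper.

```python
def step_lr_rate(epoch, learning_rate_init, learning_rate_decay_start,
                 learning_rate_decay_every, learning_rate_decay_rate):
    """Step-decayed learning rate for a given epoch."""
    # Assumed convention: a negative decay start means "never decay"
    if learning_rate_decay_start < 0 or epoch <= learning_rate_decay_start:
        return learning_rate_init
    num_decays = (epoch - learning_rate_decay_start) // learning_rate_decay_every
    return learning_rate_init * learning_rate_decay_rate ** num_decays


for epoch in range(0, 10, 3):
    print(epoch, step_lr_rate(epoch, 5e-4, 0, 3, 0.8))
```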
- sparse_caption.utils.optim.build_optimizer(params, config)
- sparse_caption.utils.optim.clip_gradient(optimizer, grad_clip)
- sparse_caption.utils.optim.get_optim(parameters, config)
sparse_caption.utils.training module
Created on 01 May 2020 15:00:54 @author: jiahuei
https://github.com/huggingface/transformers/blob/v2.9.0/examples/lightning_base.py
- class sparse_caption.utils.training.TrainingModule(config: sparse_caption.utils.config.Config)
Bases: object
Base class for training and evaluation.
- ALL_METRICS = ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'METEOR', 'ROUGE_L', 'CIDEr', 'SPICE']
- SCST_BASELINE = ['greedy', 'sample']
- SCST_SAMPLE = ['beam_search', 'random']
- static add_argparse_args(parser: Union[argparse._ArgumentGroup, argparse.ArgumentParser])
- checkpoint_path
alias of str
- collate_fn: Dict[str, Callable]
- compute_scst_loss(model_inputs, gts, loss_fn)
- classmethod eval_model(state_dict, config, split='test')
- eval_on_split(loader, split)
- get_dataloader(split: str, collate_fn: Callable, generation_mode: bool = False)
- maybe_load_checkpoint(strict=True)
- model: torch.nn.modules.module.Module
- optimizer: torch.optim.optimizer.Optimizer
- prepare()
- scst_scorer: sparse_caption.scst.scorers.CaptionScorer
- test_dataloader()
- tokenizer: sparse_caption.tokenizer.Tokenizer
- train_dataloader()
- val_dataloader()
Module contents
Created on 01 Aug 2020 18:12:28 @author: jiahuei