additional config
This commit is contained in:
parent
3cbed8144e
commit
bdb7f5ef7e
@ -15,14 +15,12 @@ from sconf import Config
|
||||
|
||||
def main(config):
|
||||
|
||||
# max_length = 768
|
||||
# image_size = [1920, 2560]
|
||||
# config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path)
|
||||
# config_vision.encoder.image_size = image_size # (height, width)
|
||||
# config_vision.decoder.max_length = max_length
|
||||
config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path)
|
||||
config_vision.encoder.image_size = [1920, 2560] # (height, width)
|
||||
config_vision.decoder.max_length = 768
|
||||
|
||||
processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
|
||||
model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path)
|
||||
model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path, config=config_vision)
|
||||
dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model.eval()
|
||||
|
Loading…
Reference in New Issue
Block a user