Dongfu Jiang committed on
Commit
342b809
1 Parent(s): 8841a8a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -5
README.md CHANGED
@@ -38,15 +38,17 @@ cand2_prefix = "<|candidate2|>"
38
  inputs = ["hello!", "I love you!"]
39
  candidates_A = ["hi!", "I hate you!"]
40
  candidates_B = ["f**k off!", "I love you, too!"]
41
- def tokenize_pair(sources:List[str], candidate1s:List[str], candidate2s:List[str]):
42
  ids = []
43
  assert len(sources) == len(candidate1s) == len(candidate2s)
 
44
  for i in range(len(sources)):
45
- source_ids = tokenizer.encode(source_prefix + sources[i])
46
- candidate1_ids = tokenizer.encode(cand1_prefix + candidate1s[i])
47
- candidate2_ids = tokenizer.encode(cand2_prefix + candidate2s[i])
 
48
  ids.append(source_ids + candidate1_ids + candidate2_ids)
49
- encodings = tokenizer.pad({"input_ids": ids}, return_tensors="pt")
50
  return encodings
51
 
52
  encodings = tokenize_pair(inputs, candidates_A, candidates_B)
 
38
  inputs = ["hello!", "I love you!"]
39
  candidates_A = ["hi!", "I hate you!"]
40
  candidates_B = ["f**k off!", "I love you, too!"]
41
+ def tokenize_pair(sources:List[str], candidate1s:List[str], candidate2s:List[str], source_max_length=1224, candidate_max_length=412):
42
  ids = []
43
  assert len(sources) == len(candidate1s) == len(candidate2s)
44
+ max_length = source_max_length + 2 * candidate_max_length
45
  for i in range(len(sources)):
46
+ source_ids = tokenizer.encode(source_prefix + sources[i], max_length=source_max_length, truncation=True)
47
+ candidate_max_length = (max_length - len(source_ids)) // 2
48
+ candidate1_ids = tokenizer.encode(cand1_prefix + candidate1s[i], max_length=candidate_max_length, truncation=True)
49
+ candidate2_ids = tokenizer.encode(cand2_prefix + candidate2s[i], max_length=candidate_max_length, truncation=True)
50
  ids.append(source_ids + candidate1_ids + candidate2_ids)
51
+ encodings = tokenizer.pad({"input_ids": ids}, return_tensors="pt", padding=True, max_length=max_length)
52
  return encodings
53
 
54
  encodings = tokenize_pair(inputs, candidates_A, candidates_B)