Petr Tsvetkov
commited on
Commit
β’
347f566
1
Parent(s):
5bd86a2
Fix the synthetic data generation pipeline
Browse files
api_wrappers/grazie_wrapper.py
CHANGED
@@ -32,7 +32,7 @@ def llm_request(prompt):
|
|
32 |
|
33 |
while output is None:
|
34 |
try:
|
35 |
-
output =
|
36 |
chat=ChatPrompt()
|
37 |
.add_system("You are a helpful assistant.")
|
38 |
.add_user(prompt),
|
|
|
32 |
|
33 |
while output is None:
|
34 |
try:
|
35 |
+
output = client.chat(
|
36 |
chat=ChatPrompt()
|
37 |
.add_system("You are a helpful assistant.")
|
38 |
.add_user(prompt),
|
dataset_statistics.py
CHANGED
@@ -9,10 +9,7 @@ from scipy.stats import stats
|
|
9 |
import config
|
10 |
|
11 |
|
12 |
-
def
|
13 |
-
start_msg = row["commit_msg_start"]
|
14 |
-
end_msg = row["commit_msg_end"]
|
15 |
-
|
16 |
edit_ops = Levenshtein.editops(start_msg, end_msg)
|
17 |
n_deletes = sum([1 if op == 'delete' else 0 for op, _, _ in edit_ops])
|
18 |
n_inserts = sum([1 if op == 'insert' else 0 for op, _, _ in edit_ops])
|
@@ -32,12 +29,18 @@ def get_statistics(row):
|
|
32 |
"changes_norm": n_changes / len(end_msg),
|
33 |
|
34 |
"lendiff": abs(len(start_msg) - len(end_msg)),
|
35 |
-
"editdist": row["editdist_related"]
|
36 |
}
|
37 |
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
def get_statistics_for_df(df: pd.DataFrame):
|
40 |
-
stats = [
|
41 |
df.iterrows()]
|
42 |
|
43 |
assert len(stats) > 0
|
|
|
9 |
import config
|
10 |
|
11 |
|
12 |
+
def get_statistics_for_sample(start_msg, end_msg, row=None):
|
|
|
|
|
|
|
13 |
edit_ops = Levenshtein.editops(start_msg, end_msg)
|
14 |
n_deletes = sum([1 if op == 'delete' else 0 for op, _, _ in edit_ops])
|
15 |
n_inserts = sum([1 if op == 'insert' else 0 for op, _, _ in edit_ops])
|
|
|
29 |
"changes_norm": n_changes / len(end_msg),
|
30 |
|
31 |
"lendiff": abs(len(start_msg) - len(end_msg)),
|
32 |
+
"editdist": row["editdist_related"] if row is not None else Levenshtein.distance(start_msg, end_msg),
|
33 |
}
|
34 |
|
35 |
|
36 |
+
def get_statistics_for_row(row):
|
37 |
+
start_msg = row["commit_msg_start"]
|
38 |
+
end_msg = row["commit_msg_end"]
|
39 |
+
return get_statistics_for_sample(start_msg, end_msg, row=row)
|
40 |
+
|
41 |
+
|
42 |
def get_statistics_for_df(df: pd.DataFrame):
|
43 |
+
stats = [get_statistics_for_row(row) for _, row in
|
44 |
df.iterrows()]
|
45 |
|
46 |
assert len(stats) > 0
|
generation_steps/synthetic_end_to_start.py
CHANGED
@@ -4,8 +4,8 @@ import pandas as pd
|
|
4 |
from tqdm import tqdm
|
5 |
|
6 |
import config
|
7 |
-
import generate_annotated_diffs
|
8 |
import dataset_statistics
|
|
|
9 |
from api_wrappers import grazie_wrapper, hf_data_loader
|
10 |
from generation_steps import examples
|
11 |
|
@@ -49,9 +49,8 @@ def generate_start_msg(end_msg, diff):
|
|
49 |
for i in range(GENERATION_ATTEMPTS):
|
50 |
start_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
|
51 |
|
52 |
-
stats =
|
53 |
-
|
54 |
-
end_msg))
|
55 |
if stats["insertions"] < REL_INSERTIONS_THRESHOLD:
|
56 |
return start_msg_pred
|
57 |
else:
|
|
|
4 |
from tqdm import tqdm
|
5 |
|
6 |
import config
|
|
|
7 |
import dataset_statistics
|
8 |
+
import generate_annotated_diffs
|
9 |
from api_wrappers import grazie_wrapper, hf_data_loader
|
10 |
from generation_steps import examples
|
11 |
|
|
|
49 |
for i in range(GENERATION_ATTEMPTS):
|
50 |
start_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
|
51 |
|
52 |
+
stats = dataset_statistics.get_statistics_for_sample(start_msg=start_msg_pred, end_msg=end_msg,)
|
53 |
+
|
|
|
54 |
if stats["insertions"] < REL_INSERTIONS_THRESHOLD:
|
55 |
return start_msg_pred
|
56 |
else:
|
generation_steps/synthetic_start_to_end.py
CHANGED
@@ -2,7 +2,6 @@ import pandas as pd
|
|
2 |
from tqdm import tqdm
|
3 |
|
4 |
import config
|
5 |
-
import generate_annotated_diffs
|
6 |
import dataset_statistics
|
7 |
from api_wrappers import grazie_wrapper
|
8 |
from generation_steps import examples
|
@@ -47,9 +46,7 @@ def generate_end_msg(start_msg, diff):
|
|
47 |
for i in range(GENERATION_ATTEMPTS):
|
48 |
end_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
|
49 |
|
50 |
-
stats =
|
51 |
-
annotated_msg=generate_annotated_diffs.get_annotated_diff(start_msg,
|
52 |
-
end_msg_pred))
|
53 |
if stats["deletions"] < REL_DELETIONS_THRESHOLD:
|
54 |
return end_msg_pred
|
55 |
else:
|
|
|
2 |
from tqdm import tqdm
|
3 |
|
4 |
import config
|
|
|
5 |
import dataset_statistics
|
6 |
from api_wrappers import grazie_wrapper
|
7 |
from generation_steps import examples
|
|
|
46 |
for i in range(GENERATION_ATTEMPTS):
|
47 |
end_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
|
48 |
|
49 |
+
stats = dataset_statistics.get_statistics_for_sample(start_msg=start_msg, end_msg=end_msg_pred, )
|
|
|
|
|
50 |
if stats["deletions"] < REL_DELETIONS_THRESHOLD:
|
51 |
return end_msg_pred
|
52 |
else:
|