JarrettYe commited on
Commit
c81248e
1 Parent(s): 3b21588

update to FSRS-5 & support time unit

Browse files
Files changed (2) hide show
  1. app.py +31 -7
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,17 +1,34 @@
1
  from typing import List, Tuple
2
- import gradio as gr
3
  import os
4
  import sys
5
 
6
  if os.environ.get("DEV_MODE"):
7
  # for local development
8
  sys.path.insert(0, os.path.abspath("../fsrs-optimizer/src/fsrs_optimizer/"))
9
- from fsrs_optimizer import Optimizer, DEFAULT_WEIGHT, FSRS, lineToTensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  def interface_func(
13
  weights: str, ratings: str, delta_ts: str, request_retention: float
14
- ) -> str:
15
  weights = weights.replace("[", "").replace("]", "")
16
  optimizer = Optimizer()
17
  optimizer.w = list(map(lambda x: float(x.strip()), weights.split(",")))
@@ -20,6 +37,12 @@ def interface_func(
20
  )
21
  default_preview = optimizer.preview(request_retention)
22
  if delta_ts != "":
 
 
 
 
 
 
23
  s_history, d_history = memory_state_sequence(ratings, delta_ts, optimizer.w)
24
  return (
25
  test_sequence,
@@ -31,13 +54,14 @@ def interface_func(
31
 
32
  def memory_state_sequence(
33
  r_history: str, t_history: str, weights: List[float]
34
- ) -> Tuple[List[float], List[float]]:
35
  fsrs = FSRS(weights)
36
  line_tensor = lineToTensor(list(zip([t_history], [r_history]))[0]).unsqueeze(1)
37
  outputs, _ = fsrs(line_tensor)
38
  stabilities, difficulties = outputs.transpose(0, 1)[0].transpose(0, 1)
39
- return map(lambda x: str(round(x, 2)), stabilities.tolist()), map(
40
- lambda x: str(round(x, 2)), difficulties.tolist()
 
41
  )
42
 
43
 
@@ -47,7 +71,7 @@ iface = gr.Interface(
47
  gr.Textbox(
48
  label="weights",
49
  lines=1,
50
- value=str(DEFAULT_WEIGHT)[1:-1],
51
  ),
52
  gr.Textbox(label="ratings", lines=1, value="3,3,3,3,1,3,3"),
53
  gr.Textbox(label="delta_ts (requried by state history)", lines=1, value=""),
 
1
  from typing import List, Tuple
2
+ import gradio as gr # type: ignore
3
  import os
4
  import sys
5
 
6
  if os.environ.get("DEV_MODE"):
7
  # for local development
8
  sys.path.insert(0, os.path.abspath("../fsrs-optimizer/src/fsrs_optimizer/"))
9
+ from fsrs_optimizer import Optimizer, DEFAULT_PARAMETER, FSRS, lineToTensor # type: ignore
10
+
11
+
12
+ def convert_delta_ts(delta_ts: str) -> List[str]:
13
+ delta_ts_list = delta_ts.replace(" ", "").split(",")
14
+ converted_delta_ts = []
15
+ for dt in delta_ts_list:
16
+ if dt.endswith("d"):
17
+ converted_delta_ts.append(dt[:-1])
18
+ elif dt.endswith("m"):
19
+ value = float(dt[:-1]) * 30
20
+ converted_delta_ts.append(str(value))
21
+ elif dt.endswith("y"):
22
+ value = float(dt[:-1]) * 365
23
+ converted_delta_ts.append(str(value))
24
+ else:
25
+ converted_delta_ts.append(dt)
26
+ return converted_delta_ts
27
 
28
 
29
  def interface_func(
30
  weights: str, ratings: str, delta_ts: str, request_retention: float
31
+ ) -> Tuple[str, str, str]:
32
  weights = weights.replace("[", "").replace("]", "")
33
  optimizer = Optimizer()
34
  optimizer.w = list(map(lambda x: float(x.strip()), weights.split(",")))
 
37
  )
38
  default_preview = optimizer.preview(request_retention)
39
  if delta_ts != "":
40
+ ratings_list = ratings.replace(" ", "").split(",")
41
+ delta_ts_list = convert_delta_ts(delta_ts)
42
+ min_len = min(len(ratings_list), len(delta_ts_list))
43
+ ratings = ",".join(ratings_list[:min_len])
44
+ delta_ts = ",".join(delta_ts_list[:min_len])
45
+
46
  s_history, d_history = memory_state_sequence(ratings, delta_ts, optimizer.w)
47
  return (
48
  test_sequence,
 
54
 
55
  def memory_state_sequence(
56
  r_history: str, t_history: str, weights: List[float]
57
+ ) -> Tuple[List[str], List[str]]:
58
  fsrs = FSRS(weights)
59
  line_tensor = lineToTensor(list(zip([t_history], [r_history]))[0]).unsqueeze(1)
60
  outputs, _ = fsrs(line_tensor)
61
  stabilities, difficulties = outputs.transpose(0, 1)[0].transpose(0, 1)
62
+ return (
63
+ list(map(lambda x: str(round(x, 2)), stabilities.tolist())),
64
+ list(map(lambda x: str(round(x, 2)), difficulties.tolist())),
65
  )
66
 
67
 
 
71
  gr.Textbox(
72
  label="weights",
73
  lines=1,
74
+ value=str(DEFAULT_PARAMETER)[1:-1],
75
  ),
76
  gr.Textbox(label="ratings", lines=1, value="3,3,3,3,1,3,3"),
77
  gr.Textbox(label="delta_ts (requried by state history)", lines=1, value=""),
requirements.txt CHANGED
@@ -1 +1 @@
1
- FSRS-Optimizer==4.28.2
 
1
+ FSRS-Optimizer==5.3.0